Update conv.cc: return TfLiteStatus from EvalHybridPerChannel, take const input/filter/bias tensors, and propagate failures at the call site via TF_LITE_ENSURE_OK

Author:       geetachavan1
Date:         2021-06-04 13:18:44 -07:00
Committed by: GitHub
Parent:       35b9cd4004
Commit:       93edcb40a1


@@ -749,11 +749,12 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
 }
 
 template <KernelType kernel_type>
-void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
-                          TfLiteConvParams* params, OpData* data,
-                          TfLiteTensor* input, TfLiteTensor* filter,
-                          TfLiteTensor* bias, TfLiteTensor* im2col,
-                          TfLiteTensor* output) {
+TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
+                                  TfLiteConvParams* params, OpData* data,
+                                  const TfLiteTensor* input,
+                                  const TfLiteTensor* filter,
+                                  const TfLiteTensor* bias,
+                                  TfLiteTensor* im2col, TfLiteTensor* output) {
   float output_activation_min, output_activation_max;
   CalculateActivationRange(params->activation, &output_activation_min,
                            &output_activation_max);
@@ -917,8 +918,9 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
     case kTfLiteFloat32:
       if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
         if (data->is_hybrid_per_channel) {
-          EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
-                                            filter, bias, im2col, output);
+          TF_LITE_ENSURE_OK(context, EvalHybridPerChannel<kernel_type>(
+                                         context, node, params, data, input,
+                                         filter, bias, im2col, output));
         } else {
           TfLiteTensor* accum_scratch =
               &context->tensors[node->temporaries
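
For reference, a minimal sketch of the TfLiteStatus / TF_LITE_ENSURE_OK error-propagation pattern this change adopts: the callee returns TfLiteStatus instead of void, and the caller wraps the call in TF_LITE_ENSURE_OK so a failure propagates rather than being dropped. The helper names (CheckInput, CallerEval) are hypothetical and not part of conv.cc.

#include "tensorflow/lite/c/common.h"

// Hypothetical callee: returns TfLiteStatus so failures can propagate.
TfLiteStatus CheckInput(TfLiteContext* context, const TfLiteTensor* input) {
  // TF_LITE_ENSURE reports the error through the context and returns
  // kTfLiteError from the current function if the condition is false.
  TF_LITE_ENSURE(context, input != nullptr);
  return kTfLiteOk;
}

// Hypothetical caller: mirrors how EvalImpl now wraps EvalHybridPerChannel.
TfLiteStatus CallerEval(TfLiteContext* context, const TfLiteTensor* input) {
  // TF_LITE_ENSURE_OK returns early with the failing status instead of
  // silently ignoring it, as a void-returning callee would force.
  TF_LITE_ENSURE_OK(context, CheckInput(context, input));
  return kTfLiteOk;
}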