mirror of
https://github.com/zebrajr/tensorflow.git
synced 2026-01-15 12:15:41 +00:00
Update conv.cc
This commit is contained in:
@@ -749,11 +749,12 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
void EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, OpData* data,
|
||||
TfLiteTensor* input, TfLiteTensor* filter,
|
||||
TfLiteTensor* bias, TfLiteTensor* im2col,
|
||||
TfLiteTensor* output) {
|
||||
TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConvParams* params, OpData* data,
|
||||
const TfLiteTensor* input,
|
||||
const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias,
|
||||
TfLiteTensor* im2col, TfLiteTensor* output) {
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
@@ -917,8 +918,9 @@ TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node) {
|
||||
case kTfLiteFloat32:
|
||||
if (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8) {
|
||||
if (data->is_hybrid_per_channel) {
|
||||
EvalHybridPerChannel<kernel_type>(context, node, params, data, input,
|
||||
filter, bias, im2col, output);
|
||||
TF_LITE_ENSURE_OK(context, EvalHybridPerChannel<kernel_type>(
|
||||
context, node, params, data, input,
|
||||
filter, bias, im2col, output));
|
||||
} else {
|
||||
TfLiteTensor* accum_scratch =
|
||||
&context->tensors[node->temporaries
|
||||
|
||||
Reference in New Issue
Block a user