diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
index c213c1ee498..c50e0a26e71 100644
--- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
@@ -80,13 +80,13 @@ static LogicalResult IsDrqTensor(Value value, Value& fq_input) {
   // fake quant op.
   // This is to support cases such as:
   // %2077 = "vhlo.composite_v1"(%73, %69, %2070) : (tensor, tensor,
-  // tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
+  // tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
   // %2078 = "tfl.reshape"(%2077, %99) : (tensor<1x?x512xf32>, tensor<2xi32>) ->
-  // tensor
+  // tensor
   // %2079 = "tfl.pseudo_qconst"() <{qtype = tensor<64x512x!quant.uniform, tensor<64x512xf32>, none) ->
-  // tensor
+  // %2080 = "tfl.dequantize"(%2079)
+  // %2081 = "tfl.fully_connected"(%2078, %2080, %0) : (tensor,
+  // tensor<64x512xf32>, none) -> tensor
   // TODO - b/422588785: Have proper support for dynamic shaped models.
   auto v = value;
   if (auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(v.getDefiningOp())) {
@@ -228,6 +228,40 @@ class PushForwardDrqFQ : public OpRewritePattern {
   }
 };
 
+// Fixes the keep_num_dims option of FC if the output rank differs from the
+// input rank even though keep_num_dims is true. This happens when the FC's
+// input has changed during quantization, e.g. by IsDrqTensor().
+// Sets keep_num_dims to false in that case; otherwise the op is not
+// compatible with the GPU delegate. See CheckGpuDelegateCompatibility() in
+// third_party/tensorflow/lite/tools/versioning/gpu_compatibility.cc.
+// Note that if the FC is followed by a Reshape, keep_num_dims will be set
+// back to true with the correct shape later by
+// EnableFullyConnectedKeepNumDimsBeforeReshape() in the optimize pass.
+struct FixFullyConnectedKeepNumDims
+    : public OpRewritePattern<FullyConnectedOp> {
+  explicit FixFullyConnectedKeepNumDims(MLIRContext* context)
+      : OpRewritePattern(context, /*benefit=*/0) {}
+
+  LogicalResult matchAndRewrite(FullyConnectedOp fc,
+                                PatternRewriter& rewriter) const override {
+    if (!fc.getKeepNumDims()) return failure();
+
+    auto input_ty =
+        mlir::dyn_cast_or_null<RankedTensorType>(fc.getInput().getType());
+    auto fc_ty = mlir::dyn_cast_or_null<RankedTensorType>(fc.getType(0));
+    if (!input_ty || !fc_ty) return failure();
+
+    auto input_shape = input_ty.getShape();
+    auto fc_shape = fc_ty.getShape();
+    if (input_shape.size() == fc_shape.size()) {
+      return failure();
+    }
+
+    fc.setKeepNumDims(false);
+    return success();
+  }
+};
+
 class StrictQuantizationPattern : public RewritePattern {
  public:
   using BaseType = StrictQuantizationPattern;
@@ -764,7 +798,7 @@ void QuantizePass::runOnOperation() {
     patterns.add(ctx, quant_params);
   }
-
+  patterns.add<FixFullyConnectedKeepNumDims>(ctx);
   (void)applyPatternsGreedily(func, std::move(patterns));
 
   // Constant quantization is a lossy transformation, so they are applied only
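
Not part of the patch: a minimal sketch of how the new pattern could be driven in isolation (for example from a test), assuming FixFullyConnectedKeepNumDims were visible outside quantize.cc (in the patch it is file-local) and reusing the same applyPatternsGreedily driver the pass already calls. The helper name ApplyKeepNumDimsFix is hypothetical.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir::TFL {

// Hypothetical helper: applies only FixFullyConnectedKeepNumDims so its effect
// can be inspected. An FC such as
//   "tfl.fully_connected"(%in, %w, %b) <{keep_num_dims = true}>
//       : (tensor<?x512xf32>, tensor<64x512xf32>, none) -> tensor<1x?x64xf32>
// has a rank-2 input but a rank-3 result, so the pattern resets keep_num_dims
// to false and the GPU delegate compatibility check no longer rejects the op.
LogicalResult ApplyKeepNumDimsFix(Operation* root) {
  RewritePatternSet patterns(root->getContext());
  patterns.add<FixFullyConnectedKeepNumDims>(root->getContext());
  return applyPatternsGreedily(root, std::move(patterns));
}

}  // namespace mlir::TFL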