From fec780d7febddc82045efa6923ed00040437de9e Mon Sep 17 00:00:00 2001
From: Byungchul Kim
Date: Mon, 22 Dec 2025 10:38:48 -0800
Subject: [PATCH] Set FC's keep_num_dims to false when output dims differ from
 input dims after quantization.

On gemma3n with decode batch > 1, this happens when the embedding is coupled
with PLE by einsum. The export steps are:

1) Initial:
   BMM([b,2048]x[2048,7680] -> [b,7680])
2) FuseInputReshape_BatchMatMulWithFlattenedRhsDims:
   BMM([b,2048]x[2048,7680] -> [b,7680])
3) ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs:
   FC([b,2048]x[2048,7680] -> [b,7680])
4) StrictQuantizationPattern (by IsDrqTensor):
   FC([b,1,2048]x[2048,7680] -> [b,7680])

When FC's keep_num_dims is false and the FC is followed by a reshape op (as in
gemma3n), keep_num_dims will be set back to true with correct shapes later by
EnableFullyConnectedKeepNumDimsBeforeReshape.

PiperOrigin-RevId: 847813526
---
 .../compiler/mlir/lite/transforms/quantize.cc | 46 ++++++++++++++++---
 1 file changed, 40 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
index c213c1ee498..c50e0a26e71 100644
--- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc
@@ -80,13 +80,13 @@ static LogicalResult IsDrqTensor(Value value, Value& fq_input) {
   // fake quant op.
   // This is to support the case such as:
   // %2077 = "vhlo.composite_v1"(%73, %69, %2070) : (tensor<...>, tensor<...>,
-  // tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
+  //     tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
   // %2078 = "tfl.reshape"(%2077, %99) : (tensor<1x?x512xf32>, tensor<2xi32>) ->
-  // tensor<?x512xf32>
+  //     tensor<?x512xf32>
   // %2079 = "tfl.pseudo_qconst"() <{qtype = tensor<64x512x!quant.uniform<...>>}>
-  // %2080 = "tfl.fully_connected"(%2078, %2079, %0) : (tensor<?x512xf32>,
-  //     tensor<64x512xf32>, none) ->
-  //     tensor<?x64xf32>
+  // %2080 = "tfl.dequantize"(%2079)
+  // %2081 = "tfl.fully_connected"(%2078, %2080, %0) : (tensor<?x512xf32>,
+  //     tensor<64x512xf32>, none) -> tensor<?x64xf32>
   // TODO - b/422588785: Have proper support for dynamic shaped models.
   auto v = value;
   if (auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(v.getDefiningOp())) {
@@ -228,6 +228,40 @@ class PushForwardDrqFQ : public OpRewritePattern {
   }
 };
 
+// Fixes the keep_num_dims option of FC if the output rank differs from the
+// input rank even though keep_num_dims is true. This happens when FC's input
+// has changed after quantization, e.g. by IsDrqTensor().
+// Sets keep_num_dims to false in that case; otherwise the op is not
+// compatible with GPU. See CheckGpuDelegateCompatibility() in
+// third_party/tensorflow/lite/tools/versioning/gpu_compatibility.cc.
+// Note that if the FC is followed by a Reshape, keep_num_dims will be set
+// back to true with a correct shape later by
+// EnableFullyConnectedKeepNumDimsBeforeReshape() in the optimize pass.
+struct FixFullyConnectedKeepNumDims
+    : public OpRewritePattern<FullyConnectedOp> {
+  explicit FixFullyConnectedKeepNumDims(MLIRContext* context)
+      : OpRewritePattern<FullyConnectedOp>(context, /*benefit=*/0) {}
+
+  LogicalResult matchAndRewrite(FullyConnectedOp fc,
+                                PatternRewriter& rewriter) const override {
+    if (!fc.getKeepNumDims()) return failure();
+
+    auto input_ty =
+        mlir::dyn_cast_or_null<RankedTensorType>(fc.getInput().getType());
+    auto fc_ty = mlir::dyn_cast_or_null<RankedTensorType>(fc.getType(0));
+    if (!input_ty || !fc_ty) return failure();
+
+    auto input_shape = input_ty.getShape();
+    auto fc_shape = fc_ty.getShape();
+    if (input_shape.size() == fc_shape.size()) {
+      return failure();
+    }
+
+    fc.setKeepNumDims(false);
+    return success();
+  }
+};
+
 class StrictQuantizationPattern : public RewritePattern {
  public:
   using BaseType = StrictQuantizationPattern;
@@ -764,7 +798,7 @@ void QuantizePass::runOnOperation() {
     patterns.add<StrictQuantizationPattern>(ctx, quant_params);
   }
-
+  patterns.add<FixFullyConnectedKeepNumDims>(ctx);
   (void)applyPatternsGreedily(func, std::move(patterns));
 
   // Constant quantization is a lossy transformation, so they are applied only
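
For review context, a minimal sketch of the mismatch this pattern targets,
using the gemma3n shapes from the commit message (b rendered as ?). The value
names, the [7680,2048] filter layout, and the elided attributes are
illustrative assumptions, not excerpts from the model:

  // keep_num_dims = true claims the output keeps the input's rank, but after
  // IsDrqTensor() expanded the input to rank 3 the output is still rank 2,
  // which CheckGpuDelegateCompatibility() rejects:
  %fc = "tfl.fully_connected"(%in, %w, %none) <{keep_num_dims = true, ...}>
      : (tensor<?x1x2048xf32>, tensor<7680x2048xf32>, none) -> tensor<?x7680xf32>

  // After FixFullyConnectedKeepNumDims, the attribute matches the actual ranks:
  %fc = "tfl.fully_connected"(%in, %w, %none) <{keep_num_dims = false, ...}>
      : (tensor<?x1x2048xf32>, tensor<7680x2048xf32>, none) -> tensor<?x7680xf32>

If a tfl.reshape consumes the FC result, as in gemma3n,
EnableFullyConnectedKeepNumDimsBeforeReshape() in the optimize pass can later
set keep_num_dims back to true together with a corrected rank-3 output shape.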