Set FC's keep_num_dims to false when the output rank differs from the input rank after quantization.

On gemma3n with decode batch > 1, this happens when the embedding is coupled with PLE via einsum.
The export steps are listed below; a sketch of the resulting rank check follows the list:
1) Initial: BMM([b,2048]x[2048,7680] -> [b,7680])
2) FuseInputReshape_BatchMatMulWithFlattenedRhsDims: BMM([b,2048]x[2048,7680] -> [b,7680])
3) ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs: FC([b,2048]x[2048,7680] -> [b,7680])
4) StrictQuantizationPattern (via IsDrqTensor): FC([b,1,2048]x[2048,7680] -> [b,7680])
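
After step 4 the FC input is rank 3 ([b,1,2048]) while the FC output is still rank 2 ([b,7680]), so keep_num_dims=true no longer describes the output. The snippet below is a minimal standalone sketch of the rank check this motivates, condensed from the diff further down; the helper name and include paths are mine and not part of the change.

  #include "mlir/IR/BuiltinTypes.h"
  #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

  // Clears keep_num_dims when the FC input rank no longer matches the FC
  // output rank (e.g. a [b,1,2048] input vs. a [b,7680] output after the DRQ
  // reshape inserted in step 4). Returns true if the attribute was changed.
  static bool ClearKeepNumDimsOnRankMismatch(mlir::TFL::FullyConnectedOp fc) {
    if (!fc.getKeepNumDims()) return false;
    auto input_ty =
        mlir::dyn_cast_or_null<mlir::RankedTensorType>(fc.getInput().getType());
    auto output_ty =
        mlir::dyn_cast_or_null<mlir::RankedTensorType>(fc.getType(0));
    if (!input_ty || !output_ty) return false;
    if (input_ty.getRank() == output_ty.getRank()) return false;
    fc.setKeepNumDims(false);
    return true;
  }

The new FixFullyConnectedKeepNumDims pattern in the diff applies exactly this check inside the quantize pass.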

When this sets FC's keep_num_dims to false and the FC is followed by a reshape op (as in gemma3n), keep_num_dims will be set back to true with the correct shapes later by EnableFullyConnectedKeepNumDimsBeforeReshape.
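
For illustration only, here is a rough sketch of that re-enabling idea, under the assumption that the FC result is consumed solely by a reshape whose result shape matches what keep_num_dims=true would produce. This is not the actual EnableFullyConnectedKeepNumDimsBeforeReshape implementation; the pattern name, includes, and matching conditions are assumptions.

  #include "llvm/ADT/SmallVector.h"
  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/PatternMatch.h"
  #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

  // Sketch only: re-enable keep_num_dims and fold away the trailing reshape
  // when the reshape already produces the keep_num_dims=true shape
  // (input batch dims + number of output units).
  struct ReEnableKeepNumDimsSketch
      : public mlir::OpRewritePattern<mlir::TFL::FullyConnectedOp> {
    using mlir::OpRewritePattern<mlir::TFL::FullyConnectedOp>::OpRewritePattern;

    mlir::LogicalResult matchAndRewrite(
        mlir::TFL::FullyConnectedOp fc,
        mlir::PatternRewriter& rewriter) const override {
      if (fc.getKeepNumDims()) return mlir::failure();
      if (!fc->getResult(0).hasOneUse()) return mlir::failure();
      auto reshape = llvm::dyn_cast<mlir::TFL::ReshapeOp>(
          *fc->getResult(0).getUsers().begin());
      if (!reshape) return mlir::failure();
      auto input_ty =
          mlir::dyn_cast<mlir::RankedTensorType>(fc.getInput().getType());
      auto reshape_ty =
          mlir::dyn_cast<mlir::RankedTensorType>(reshape.getType());
      if (!input_ty || !reshape_ty) return mlir::failure();
      if (input_ty.getRank() != reshape_ty.getRank()) return mlir::failure();
      // Expected keep_num_dims=true shape: input batch dims + output units.
      llvm::SmallVector<int64_t> expected(input_ty.getShape().begin(),
                                          input_ty.getShape().end());
      expected.back() = reshape_ty.getShape().back();
      if (llvm::ArrayRef<int64_t>(expected) != reshape_ty.getShape())
        return mlir::failure();
      rewriter.modifyOpInPlace(fc, [&] {
        fc->getResult(0).setType(mlir::RankedTensorType::get(
            expected, reshape_ty.getElementType()));
        fc.setKeepNumDims(true);
      });
      rewriter.replaceOp(reshape, fc->getResult(0));
      return mlir::success();
    }
  };

The key point is that such a rewrite only fires when the reshape already restores the batch dimensions, so clearing keep_num_dims in the quantize pass does not lose information for the gemma3n FC+reshape case.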

PiperOrigin-RevId: 847813526
Byungchul Kim
2025-12-22 10:38:48 -08:00
committed by TensorFlower Gardener
parent 9ca49fcfa5
commit fec780d7fe


@@ -80,13 +80,13 @@ static LogicalResult IsDrqTensor(Value value, Value& fq_input) {
   // fake quant op.
   // This is to support the case such as:
   // %2077 = "vhlo.composite_v1"(%73, %69, %2070) : (tensor<i32>, tensor<i32>,
-  // tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
+  // tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
   // %2078 = "tfl.reshape"(%2077, %99) : (tensor<1x?x512xf32>, tensor<2xi32>) ->
-  // tensor<?x512xf32>
+  // tensor<?x512xf32>
   // %2079 = "tfl.pseudo_qconst"() <{qtype = tensor<64x512x!quant.uniform<i8....
-  // %2080 = "tfl.dequantize"(%2079) %2081 = "tfl.fully_connected"
-  // (%2078, %2080, %0) : (tensor<?x512xf32>, tensor<64x512xf32>, none) ->
-  // tensor<?x64xf32>
+  // %2080 = "tfl.dequantize"(%2079)
+  // %2081 = "tfl.fully_connected"(%2078, %2080, %0) : (tensor<?x512xf32>,
+  // tensor<64x512xf32>, none) -> tensor<?x64xf32>
   // TODO - b/422588785: Have proper support for dynamic shaped models.
   auto v = value;
   if (auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(v.getDefiningOp())) {
@@ -228,6 +228,40 @@ class PushForwardDrqFQ : public OpRewritePattern<stablehlo::CompositeOp> {
   }
 };
 
+// Fixes the keep_num_dims option of FC when the output rank differs from the
+// input rank even though keep_num_dims is true. This happens when the FC
+// input has been changed after quantization, e.g. by IsDrqTensor().
+// Sets keep_num_dims to false in that case; otherwise the op is not
+// compatible with GPU. See CheckGpuDelegateCompatibility() in
+// third_party/tensorflow/lite/tools/versioning/gpu_compatibility.cc.
+// Note that if the FC is followed by a Reshape, keep_num_dims will be set
+// back to true with the correct shape later by
+// EnableFullyConnectedKeepNumDimsBeforeReshape() in the optimize pass.
+struct FixFullyConnectedKeepNumDims
+    : public OpRewritePattern<FullyConnectedOp> {
+  explicit FixFullyConnectedKeepNumDims(MLIRContext* context)
+      : OpRewritePattern<TFL::FullyConnectedOp>(context, /*benefit=*/0) {}
+
+  LogicalResult matchAndRewrite(FullyConnectedOp fc,
+                                PatternRewriter& rewriter) const override {
+    if (!fc.getKeepNumDims()) return failure();
+    auto input_ty =
+        mlir::dyn_cast_or_null<RankedTensorType>(fc.getInput().getType());
+    auto fc_ty = mlir::dyn_cast_or_null<RankedTensorType>(fc.getType(0));
+    if (!input_ty || !fc_ty) return failure();
+    auto input_shape = input_ty.getShape();
+    auto fc_shape = fc_ty.getShape();
+    if (input_shape.size() == fc_shape.size()) {
+      return failure();
+    }
+    fc.setKeepNumDims(false);
+    return success();
+  }
+};
+
 class StrictQuantizationPattern : public RewritePattern {
  public:
   using BaseType = StrictQuantizationPattern;
@@ -764,7 +798,7 @@ void QuantizePass::runOnOperation() {
     patterns.add<TFLFullQuantization, TFLFullQuantizationReverse>(ctx,
                                                                   quant_params);
   }
+  patterns.add<FixFullyConnectedKeepNumDims>(ctx);
   (void)applyPatternsGreedily(func, std::move(patterns));
 
   // Constant quantization is a lossy transformation, so they are applied only