Set FC's keep_num_dims to false when output dims is different from input dims after quantization.
On gemma3n with decode batch > 1, it happens when the embedding is coupled with PLE by einsum. The export steps are:

1) Initial: BMM([b,2048]x[2048,7680] -> [b,7680])
2) FuseInputReshape_BatchMatMulWithFlattenedRhsDims: BMM([b,2048]x[2048,7680] -> [b,7680])
3) ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs: FC([b,2048]x[2048,7680] -> [b,7680])
4) StrictQuantizationPattern (by IsDrqTensor): FC([b,1,2048]x[2048,7680] -> [b,7680])

When FC's keep_num_dims is false and it's followed by a reshape op (like gemma3n), keep_num_dims will be set to true later with correct shapes by EnableFullyConnectedKeepNumDimsBeforeReshape.

PiperOrigin-RevId: 847813526
Committed by: TensorFlower Gardener
Parent: 9ca49fcfa5
Commit: fec780d7fe
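For illustration only (not part of the commit): a minimal standalone C++ sketch of why step 4 in the export sequence above forces keep_num_dims to false. It assumes the TFLite FullyConnected semantics that keep_num_dims=true preserves the input rank (innermost dimension replaced by the number of output units) while keep_num_dims=false flattens the result to rank 2; KeepNumDimsShape and the batch value are illustrative, not names from the pass.

#include <cstdio>
#include <vector>

// With keep_num_dims=true, FullyConnected keeps the input's leading dimensions
// and replaces the innermost one with the number of output units; with
// keep_num_dims=false the result is flattened to rank 2.
std::vector<int> KeepNumDimsShape(const std::vector<int>& input_shape, int units) {
  std::vector<int> out(input_shape);
  out.back() = units;
  return out;
}

int main() {
  const int b = 4;  // illustrative decode batch > 1
  const std::vector<int> recorded_output = {b, 7680};     // FC result type in the IR
  const std::vector<int> rewritten_input = {b, 1, 2048};  // FC input after step 4
  const std::vector<int> kept = KeepNumDimsShape(rewritten_input, 7680);  // {b, 1, 7680}
  // Rank mismatch: keep_num_dims=true would imply a rank-3 result, but the op's
  // recorded result type is rank 2, so the flag has to be reset to false.
  std::printf("keep_num_dims rank %zu vs recorded rank %zu -> %s\n",
              kept.size(), recorded_output.size(),
              kept.size() == recorded_output.size() ? "flag can stay true"
                                                    : "set keep_num_dims to false");
  return 0;
}

With the commit's example shapes, the sketch takes the reset branch, which is exactly what the pattern added in the diff below does.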
@@ -80,13 +80,13 @@ static LogicalResult IsDrqTensor(Value value, Value& fq_input) {
   // fake quant op.
   // This is to support the case such as:
   // %2077 = "vhlo.composite_v1"(%73, %69, %2070) : (tensor<i32>, tensor<i32>,
   // tensor<1x?x512xf32>) -> tensor<1x?x512xf32>
   // %2078 = "tfl.reshape"(%2077, %99) : (tensor<1x?x512xf32>, tensor<2xi32>) ->
   // tensor<?x512xf32>
   // %2079 = "tfl.pseudo_qconst"() <{qtype = tensor<64x512x!quant.uniform<i8....
-  // %2080 = "tfl.dequantize"(%2079) %2081 = "tfl.fully_connected"
-  // (%2078, %2080, %0) : (tensor<?x512xf32>, tensor<64x512xf32>, none) ->
-  // tensor<?x64xf32>
+  // %2080 = "tfl.dequantize"(%2079)
+  // %2081 = "tfl.fully_connected"(%2078, %2080, %0) : (tensor<?x512xf32>,
+  // tensor<64x512xf32>, none) -> tensor<?x64xf32>
   // TODO - b/422588785: Have proper support for dynamic shaped models.
   auto v = value;
   if (auto reshape_op = llvm::dyn_cast_or_null<ReshapeOp>(v.getDefiningOp())) {
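A hedged sketch of the reshape handling that the hunk above documents: IsDrqTensor looks through a tfl.reshape sitting between the composite output and the dequantized fully_connected input, so the producer can still be found for dynamically shaped models. LookThroughReshape is an illustrative helper name, not code from the pass, and it assumes the MLIR/TFL headers already included by this file.

static mlir::Value LookThroughReshape(mlir::Value v) {
  // If the value is produced by a tfl.reshape (%2078 in the example above),
  // continue the match on the reshape's operand (%2077) instead.
  if (auto reshape_op =
          llvm::dyn_cast_or_null<mlir::TFL::ReshapeOp>(v.getDefiningOp())) {
    return reshape_op.getInput();
  }
  return v;  // no reshape in between; match on the value itself
}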
@@ -228,6 +228,40 @@ class PushForwardDrqFQ : public OpRewritePattern<stablehlo::CompositeOp> {
   }
 };
 
+// Fixes keep_num_dims option of FC if output dims is different from input dims
+// though keep_num_dims is true. It happens when FC's input has changed after
+// quantization, e.g. by IsDrqTensor().
+// Sets keep_num_dims to false if that's the case. Otherwise, it's not
+// compatible with GPU. See CheckGpuDelegateCompatibility() in
+// third_party/tensorflow/lite/tools/versioning/gpu_compatibility.cc.
+// Note that if FC is followed by Reshape, the keep_num_dims will be set to true
+// with a correct shape later by EnableFullyConnectedKeepNumDimsBeforeReshape()
+// in optimize pass.
+struct FixFullyConnectedKeepNumDims
+    : public OpRewritePattern<FullyConnectedOp> {
+  explicit FixFullyConnectedKeepNumDims(MLIRContext* context)
+      : OpRewritePattern<TFL::FullyConnectedOp>(context, /*benefit=*/0) {}
+
+  LogicalResult matchAndRewrite(FullyConnectedOp fc,
+                                PatternRewriter& rewriter) const override {
+    if (!fc.getKeepNumDims()) return failure();
+
+    auto input_ty =
+        mlir::dyn_cast_or_null<RankedTensorType>(fc.getInput().getType());
+    auto fc_ty = mlir::dyn_cast_or_null<RankedTensorType>(fc.getType(0));
+    if (!input_ty || !fc_ty) return failure();
+
+    auto input_shape = input_ty.getShape();
+    auto fc_shape = fc_ty.getShape();
+    if (input_shape.size() == fc_shape.size()) {
+      return failure();
+    }
+
+    fc.setKeepNumDims(false);
+    return success();
+  }
+};
+
 class StrictQuantizationPattern : public RewritePattern {
  public:
   using BaseType = StrictQuantizationPattern;
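One note on the design choice visible in the constructor above: the pattern hardcodes /*benefit=*/0, the lowest PatternBenefit, so the greedy driver tries any higher-benefit quantization patterns on an op before attempting this fix-up. A hedged registration sketch of that ordering (SomeHigherPriorityPattern is a placeholder name, not a pattern from this pass):

void PopulateExamplePatterns(mlir::MLIRContext* ctx,
                             mlir::RewritePatternSet& patterns) {
  // Hypothetical pattern registered with the default benefit of 1; it is
  // attempted first on each op.
  // patterns.add<SomeHigherPriorityPattern>(ctx);
  // FixFullyConnectedKeepNumDims sets benefit 0 in its own constructor, so it
  // is only attempted after the higher-benefit patterns fail to match.
  patterns.add<FixFullyConnectedKeepNumDims>(ctx);
}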
@@ -764,7 +798,7 @@ void QuantizePass::runOnOperation() {
     patterns.add<TFLFullQuantization, TFLFullQuantizationReverse>(ctx,
                                                                   quant_params);
   }
-
+  patterns.add<FixFullyConnectedKeepNumDims>(ctx);
   (void)applyPatternsGreedily(func, std::move(patterns));
 
   // Constant quantization is a lossy transformation, so they are applied only