Byungchul Kim fec780d7fe Set FC's keep_num_dims to false when the output dims differ from the input dims after quantization.
On gemma3n with decode batch > 1, this happens when the embedding is coupled with PLE by einsum.
The export steps are:
1) Initial: BMM([b,2048]x[2048,7680] -> [b,7680])
2) FuseInputReshape_BatchMatMulWithFlattenedRhsDims: BMM([b,2048]x[2048,7680] -> [b,7680])
3) ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs: FC([b,2048]x[2048,7680] -> [b,7680])
4) StrictQuantizationPattern(by IsDrqTensor): FC([b,1,2048]x[2048,7680] -> [b,7680])
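The shape mismatch introduced in step 4 can be sketched with a toy model of TFLite's FULLY_CONNECTED shape semantics. This is an illustrative numpy sketch, not the real kernel; the helper name `fully_connected` and the concrete batch size are assumptions:

```python
import numpy as np

def fully_connected(x, w, keep_num_dims):
    # Sketch of TFLite FULLY_CONNECTED output shapes (illustrative only).
    # x: input of rank >= 2 whose last dim matches w's first dim.
    # w: rank-2 constant RHS [in_features, out_features].
    in_features, out_features = w.shape
    if keep_num_dims:
        # Preserve all leading dims: [..., in] -> [..., out].
        return x @ w
    # Flatten leading dims: [..., in] -> [prod(leading), out].
    return x.reshape(-1, in_features) @ w

b = 4                                # hypothetical decode batch
x = np.zeros((b, 1, 2048))           # rank-3 input after StrictQuantizationPattern
w = np.zeros((2048, 7680))           # rank-2 constant RHS

# keep_num_dims=false flattens [b,1,2048] -> [b,2048], giving the expected [b,7680].
assert fully_connected(x, w, keep_num_dims=False).shape == (b, 7680)
# keep_num_dims=true would instead produce [b,1,7680], mismatching the graph.
assert fully_connected(x, w, keep_num_dims=True).shape == (b, 1, 7680)
```

This is why keep_num_dims must be false once the quantization pattern raises the input rank while the consumer still expects the rank-2 output.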

When FC's keep_num_dims is false and the FC is followed by a reshape op (as in gemma3n), keep_num_dims will later be set back to true with the correct shapes by EnableFullyConnectedKeepNumDimsBeforeReshape.
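Why that later rewrite is safe can be seen numerically: FC with keep_num_dims=false followed by a reshape back to the leading dims computes the same values as FC with keep_num_dims=true alone. A minimal numpy sketch, assuming hypothetical gemma3n-like shapes:

```python
import numpy as np

# Hypothetical shapes mirroring the example above; not taken from the real graph.
b, in_f, out_f = 4, 2048, 7680
rng = np.random.default_rng(0)
x = rng.random((b, 1, in_f))
w = rng.random((in_f, out_f))

# Path 1: keep_num_dims=false flattens to [b, out], then reshape restores rank 3.
fc_flat = x.reshape(-1, in_f) @ w            # [b, 7680]
reshaped = fc_flat.reshape(b, 1, out_f)      # [b, 1, 7680]

# Path 2: keep_num_dims=true preserves leading dims, making the reshape redundant.
fc_keep = x @ w                              # [b, 1, 7680]

assert np.allclose(reshaped, fc_keep)
```

So setting keep_num_dims to false at quantization time loses nothing: when a reshape follows, the later pass can fold it back into the FC.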

PiperOrigin-RevId: 847813526
2025-12-22 10:45:22 -08:00