On gemma3n with decode batch > 1, this occurs when the embedding is coupled with the per-layer embedding (PLE) via einsum. The export steps are:

1) Initial: BMM([b,2048] x [2048,7680] -> [b,7680])
2) FuseInputReshape_BatchMatMulWithFlattenedRhsDims: BMM([b,2048] x [2048,7680] -> [b,7680])
3) ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs: FC([b,2048] x [2048,7680] -> [b,7680])
4) StrictQuantizationPattern (via IsDrqTensor): FC([b,1,2048] x [2048,7680] -> [b,7680])

When the FC's keep_num_dims is false and it is followed by a reshape op (as in gemma3n), keep_num_dims is later set to true with the correct shapes by EnableFullyConnectedKeepNumDimsBeforeReshape.

PiperOrigin-RevId: 847813526
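
To make the pattern concrete, below is a minimal sketch of the graph shape the converter sees, not the gemma3n source itself. The 2048/7680 widths come from the steps above; the batch size, the einsum equation, and the [b, 30, 256] reshape target are illustrative assumptions.

    # Minimal sketch of the einsum-coupled embedding + PLE pattern.
    # hidden=2048 and ple=7680 come from the shapes in the steps above;
    # b, the einsum equation, and the reshape target are assumptions.
    import tensorflow as tf

    b = 4             # decode batch > 1
    hidden = 2048     # embedding width
    ple = 7680        # coupled PLE width

    x = tf.random.normal([b, hidden])    # embedding output, rank 2
    w = tf.random.normal([hidden, ple])  # rank-2 constant RHS

    # The einsum coupling lowers to BMM([b,2048] x [2048,7680] -> [b,7680])
    # and then, because the RHS is a rank-2 constant, to FullyConnected.
    y = tf.einsum('bh,hp->bp', x, w)

    # A trailing reshape like this one (target shape assumed; 30 * 256 = 7680)
    # is what EnableFullyConnectedKeepNumDimsBeforeReshape keys on to set
    # keep_num_dims = true with the correct shapes.
    y = tf.reshape(y, [b, 30, 256])

For reference, TFLite's FullyConnected with keep_num_dims=false flattens all leading input dimensions and produces a rank-2 output, which is why step 4's [b,1,2048] input still yields [b,7680]; with keep_num_dims=true the leading dimensions are preserved in the output, which is what the trailing reshape needs.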