Sometimes it is useful to turn an XlaComputation straight into a HloModule in a
test. This is already functionality we basically support, but until now the
computation had to be in the form of an XlaBuilder, which is not always
practical.
PiperOrigin-RevId: 847856677
On gemma3n with decode batch > 1, it happens when the embedding is coupled with PLE by einsum.
The export steps are:
1) Initial: BMM([b,2048]x[2048,7680] -> [b,7680])
2) FuseInputReshape_BatchMatMulWithFlattenedRhsDims: BMM([b,2048]x[2048,7680] -> [b,7680])
3) ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs: FC([b,2048]x[2048,7680] -> [b,7680])
4) StrictQuantizationPattern(by IsDrqTensor): FC([b,1,2048]x[2048,7680] -> [b,7680])
When FC's keep_num_dims is false and it's followed by reshape op (like gemma3n), keep_num_dims will be set to true later with correct shapes by EnableFullyConnectedKeepNumDimsBeforeReshape.
PiperOrigin-RevId: 847813526
This change for the new autotuner. The new autotuner with its Triton backend competes with cuDNN fusions leading to flaky tests. Also some tests disable some autotuning paths via --xla_gpu_cudnn_gemm_fusion_level or --xla_gpu_cublas_fallback which are not fully compatible with the new autotuner. Other tests rely on the order of the backends, which would be resolved by adding a backend selection mechanism.
PiperOrigin-RevId: 847750954
The tests in xla/backends/gpu/codegen/triton/BUILD are already configured to run only on specific GPU backends, making the if_gpu_is_configured check on the srcs redundant.
PiperOrigin-RevId: 847738574
This CL refactors the XLA profiler's state-checking mechanism to resolve GIL deadlocks and improve performance.
Previously, the C++ profiler context would import a Python module to update the profiler's state. This operation, performed while holding the GIL, could cause deadlocks if the import failed (e.g., in a JAX-only environment).
This change replaces the fragile cross-language import with a shared C++ std::atomic<bool>. Python code now queries this state via a new, low-overhead C function (is_traceme_enabled_raw) instead of ctypes.
This approach eliminates the deadlocks, decouples the C++ profiler from Python modules, and maintains high performance for the state check. The internal C++ API was also updated to use a safer reference instead of a raw pointer.
PiperOrigin-RevId: 847261952