[pallas backend] Require sm90+ for mosaic (#171531)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/171531
Approved by: https://github.com/choijon5
ghstack dependencies: #171475, #171485
This commit is contained in:
Oguz Ulgen
2025-12-30 13:53:31 -08:00
committed by PyTorch MergeBot
parent 2fff179ded
commit 98bc4d77e1

View File

@@ -44,7 +44,7 @@ def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, in
@functools.cache
def has_jax_cuda_backend() -> bool:
"""Check if JAX has CUDA backend support."""
"""Check if JAX has CUDA backend support with SM90+ (required by Mosaic GPU)."""
if not has_jax_package():
return False
try:
@@ -52,7 +52,16 @@ def has_jax_cuda_backend() -> bool:
# Check if CUDA backend is available
devices = jax.devices("gpu")
return len(devices) > 0
if len(devices) == 0:
return False
# Mosaic GPU requires SM90+ (compute capability 9.0+)
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
if major < 9:
return False
return True
except Exception:
return False