mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[pallas backend] Require sm90+ for mosaic (#171531)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/171531 Approved by: https://github.com/choijon5 ghstack dependencies: #171475, #171485
This commit is contained in:
committed by
PyTorch MergeBot
parent
2fff179ded
commit
98bc4d77e1
@@ -44,7 +44,7 @@ def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, in
|
||||
|
||||
@functools.cache
|
||||
def has_jax_cuda_backend() -> bool:
|
||||
"""Check if JAX has CUDA backend support."""
|
||||
"""Check if JAX has CUDA backend support with SM90+ (required by Mosaic GPU)."""
|
||||
if not has_jax_package():
|
||||
return False
|
||||
try:
|
||||
@@ -52,7 +52,16 @@ def has_jax_cuda_backend() -> bool:
|
||||
|
||||
# Check if CUDA backend is available
|
||||
devices = jax.devices("gpu")
|
||||
return len(devices) > 0
|
||||
if len(devices) == 0:
|
||||
return False
|
||||
|
||||
# Mosaic GPU requires SM90+ (compute capability 9.0+)
|
||||
if torch.cuda.is_available():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 9:
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user