[ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)

This PR add a sm_121a flag for row-wise scaled matmuls on DGX Spark.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167734
Approved by: https://github.com/eqy, https://github.com/cyyever
This commit is contained in:
Aidyn-A
2025-11-18 08:15:46 +00:00
committed by PyTorch MergeBot
parent 9760a633ba
commit 2f023bf7b9

View File

@@ -118,6 +118,12 @@ if(INTERN_BUILD_ATEN_OPS)
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
endif()
endif()
# We will need to gate against CUDA version, sm_121a was introduced in CUDA 12.9
if("${_arch}" STREQUAL "121a" AND CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
if(_existing_arch_flags MATCHES ".*compute_120.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
endif()
endif()
endforeach()
list(JOIN _file_compile_flags " " _file_compile_flags)
@@ -126,7 +132,7 @@ if(INTERN_BUILD_ATEN_OPS)
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a;103a;120a")
"89;90a;100a;103a;120a;121a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"90a")