mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[ATen][CUDA] Add sm_121a flag for RowwiseScaledMM (#167734)
This PR add a sm_121a flag for row-wise scaled matmuls on DGX Spark. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167734 Approved by: https://github.com/eqy, https://github.com/cyyever
This commit is contained in:
committed by
PyTorch MergeBot
parent
9760a633ba
commit
2f023bf7b9
@@ -118,6 +118,12 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
|
||||
endif()
|
||||
endif()
|
||||
# We will need to gate against CUDA version, sm_121a was introduced in CUDA 12.9
|
||||
if("${_arch}" STREQUAL "121a" AND CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
|
||||
if(_existing_arch_flags MATCHES ".*compute_120.*")
|
||||
list(APPEND _file_compile_flags "-gencode;arch=compute_121a,code=sm_121a")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
list(JOIN _file_compile_flags " " _file_compile_flags)
|
||||
|
||||
@@ -126,7 +132,7 @@ if(INTERN_BUILD_ATEN_OPS)
|
||||
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
|
||||
"89;90a;100a;103a;120a")
|
||||
"89;90a;100a;103a;120a;121a")
|
||||
_BUILD_FOR_ADDITIONAL_ARCHS(
|
||||
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
|
||||
"90a")
|
||||
|
||||
Reference in New Issue
Block a user