diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 05acc275b46..3bd738a372d 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -730,6 +730,8 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { #ifdef USE_ROCM constexpr auto aotriton_mem_efficient_dtypes = c10::array_of(at::kHalf, at::kFloat, at::kBFloat16); + constexpr auto ck_mem_efficient_dtypes = + c10::array_of(at::kHalf, at::kBFloat16); #else constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes = c10::array_of(at::kHalf, at::kFloat, at::kBFloat16); @@ -789,6 +791,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { return false; } } + if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { + return check_tensor_dtype(params, ck_mem_efficient_dtypes, debug); + } return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug); #else auto dprop = at::cuda::getCurrentDeviceProperties();