From 9d4de265db664c6a8cb4616701f54fffad1e04f0 Mon Sep 17 00:00:00 2001
From: Xiaodong Wang
Date: Sat, 12 Apr 2025 19:36:17 +0000
Subject: [PATCH] [AMD] Block mem efficient attention for FP32 in CK backend
 (#151132)

Summary: CK doesn't support FP32 attention, but aotriton does. If we prefer
CK, and the input dtype is FP32, we'll select mem efficient attention but CK
doesn't support it. So we'll exclude mem eff attention and pick math.

Differential Revision: D72880985

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151132
Approved by: https://github.com/yoyoyocmu
---
 aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 05acc275b46..3bd738a372d 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -730,6 +730,8 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
 #ifdef USE_ROCM
   constexpr auto aotriton_mem_efficient_dtypes =
       c10::array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
+  constexpr auto ck_mem_efficient_dtypes =
+      c10::array_of<at::ScalarType>(at::kHalf, at::kBFloat16);
 #else
   constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       c10::array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
@@ -789,6 +791,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       return false;
     }
   }
+  if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) {
+    return check_tensor_dtype(params, ck_mem_efficient_dtypes, debug);
+  }
   return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
 #else
   auto dprop = at::cuda::getCurrentDeviceProperties();
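
For readers who want to trace the effect of the new check without building ATen,
here is a minimal standalone C++ sketch of the dtype-gating logic this patch adds.
The enums, the dtype_allowed() helper, and the simplified
can_use_mem_efficient_attention() signature are illustrative stand-ins, not the
real ATen/c10 APIs (at::ROCmFABackend, at::ScalarType, check_tensor_dtype, ...).

// Minimal standalone sketch of the dtype gating added by this patch.
// All types and helpers below are illustrative stand-ins for ATen internals.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>

enum class ScalarType { Half, Float, BFloat16 };
enum class ROCmFABackend { Default, AOTriton, Ck };

// Stand-in for check_tensor_dtype(): true if dtype is in the allowed set.
template <std::size_t N>
bool dtype_allowed(ScalarType dtype, const std::array<ScalarType, N>& allowed) {
  return std::find(allowed.begin(), allowed.end(), dtype) != allowed.end();
}

// Mirrors the patched dtype check: CK supports only fp16/bf16 for
// mem-efficient attention, while aotriton also supports fp32.
bool can_use_mem_efficient_attention(ScalarType dtype, ROCmFABackend backend) {
  constexpr std::array<ScalarType, 3> aotriton_dtypes{
      ScalarType::Half, ScalarType::Float, ScalarType::BFloat16};
  constexpr std::array<ScalarType, 2> ck_dtypes{
      ScalarType::Half, ScalarType::BFloat16};
  if (backend == ROCmFABackend::Ck) {
    return dtype_allowed(dtype, ck_dtypes);
  }
  return dtype_allowed(dtype, aotriton_dtypes);
}

int main() {
  // With CK preferred, fp32 is rejected (prints 0), so the SDPA dispatcher
  // would fall back to the math backend; with aotriton, fp32 stays eligible
  // (prints 1).
  std::cout << can_use_mem_efficient_attention(ScalarType::Float,
                                               ROCmFABackend::Ck)
            << can_use_mem_efficient_attention(ScalarType::Float,
                                               ROCmFABackend::AOTriton)
            << '\n';
}

The actual fallback to the math backend happens later in SDPA backend selection,
once mem efficient attention has been ruled out; the sketch only mirrors the
dtype check that this patch changes.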