[AMD] Block mem efficient attention for FP32 in CK backend (#151132)

Summary: CK doesn't support FP32 attention, but aotriton does. If CK is the preferred ROCm flash-attention backend and the input dtype is FP32, the dispatcher would still select mem-efficient attention, which CK cannot run. This change excludes mem-efficient attention for FP32 under CK, so dispatch falls back to the math backend.
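
As an illustration of the user-visible effect, here is a hedged sketch (not part of this commit) assuming a ROCm build where CK is set as the preferred flash-attention backend; only the public scaled_dot_product_attention API is used:

    import torch
    import torch.nn.functional as F

    # FP32 query/key/value on a ROCm device.
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32)
               for _ in range(3))

    # Before this change: the dispatcher could pick mem-efficient attention,
    # which the CK backend cannot execute for FP32.
    # After this change: mem-efficient attention is filtered out for FP32
    # under CK, and the math backend is used instead.
    out = F.scaled_dot_product_attention(q, k, v)

    # Casting to a CK-supported dtype (half or bfloat16) keeps the
    # mem-efficient path eligible.
    out_half = F.scaled_dot_product_attention(q.half(), k.half(), v.half())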

Differential Revision: D72880985

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151132
Approved by: https://github.com/yoyoyocmu
Author: Xiaodong Wang
Date: 2025-04-12 19:36:17 +00:00
Committed by: PyTorch MergeBot
Parent: 71073caa00
Commit: 9d4de265db


@@ -730,6 +730,8 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
 #ifdef USE_ROCM
   constexpr auto aotriton_mem_efficient_dtypes =
       c10::array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
+  constexpr auto ck_mem_efficient_dtypes =
+      c10::array_of<at::ScalarType>(at::kHalf, at::kBFloat16);
 #else
   constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       c10::array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
@@ -789,6 +791,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       return false;
     }
   }
+  if (at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) {
+    return check_tensor_dtype(params, ck_mem_efficient_dtypes, debug);
+  }
   return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
 #else
   auto dprop = at::cuda::getCurrentDeviceProperties();
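
For completeness, a hedged sketch of how the new fallback interacts with explicit backend selection; torch.nn.attention.sdpa_kernel is the public API for restricting SDPA backends, and the exact error behavior may vary by PyTorch version:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # FP32 self-attention inputs on a ROCm device with CK preferred.
    q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32)

    # Restricting SDPA to mem-efficient attention alone now leaves no
    # eligible kernel for FP32 under CK; including MATH in the allowed set
    # lets the call fall back instead of failing.
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)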