[AMD] Block mem efficient attention for FP32 in CK backend (#151132)
Summary: CK doesn't support FP32 attention, but aotriton does. If CK is the preferred backend and the input dtype is FP32, mem-efficient attention would still be selected even though CK cannot run it. Exclude mem-efficient attention in that case so the math backend is picked instead.

Differential Revision: D72880985
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151132
Approved by: https://github.com/yoyoyocmu
committed by PyTorch MergeBot
parent 71073caa00
commit 9d4de265db
@@ -730,6 +730,8 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
 #ifdef USE_ROCM
   constexpr auto aotriton_mem_efficient_dtypes =
       c10::array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
+  constexpr auto ck_mem_efficient_dtypes =
+      c10::array_of<at::ScalarType>(at::kHalf, at::kBFloat16);
 #else
   constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       c10::array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
@@ -789,6 +791,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       return false;
     }
   }
+  if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) {
+    return check_tensor_dtype(params, ck_mem_efficient_dtypes, debug);
+  }
   return check_tensor_dtype(params, aotriton_mem_efficient_dtypes, debug);
 #else
   auto dprop = at::cuda::getCurrentDeviceProperties();
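For orientation, here is a minimal, self-contained C++ sketch of the dtype gating the hunks above introduce. The ScalarType and ROCmFABackend enums, the allow-list arrays, and dtype_allowed are simplified stand-ins for the real c10/ATen types and for check_tensor_dtype; this is only an illustration of the decision logic, not the actual implementation.

// Minimal sketch of the backend-dependent dtype gating described above.
// All types here are simplified stand-ins for the real c10/ATen ones.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>

enum class ScalarType { Half, Float, BFloat16 };
enum class ROCmFABackend { Default, Ck };

// aotriton supports FP16, FP32, and BF16; CK supports only FP16 and BF16.
constexpr std::array<ScalarType, 3> aotriton_mem_efficient_dtypes{
    ScalarType::Half, ScalarType::Float, ScalarType::BFloat16};
constexpr std::array<ScalarType, 2> ck_mem_efficient_dtypes{
    ScalarType::Half, ScalarType::BFloat16};

// Simplified stand-in for check_tensor_dtype: accept the dtype only if it
// appears in the chosen allow-list.
template <std::size_t N>
bool dtype_allowed(ScalarType dtype, const std::array<ScalarType, N>& allowed) {
  return std::find(allowed.begin(), allowed.end(), dtype) != allowed.end();
}

// Mirrors the new control flow: CK preferred -> CK allow-list, otherwise aotriton's.
bool can_use_mem_efficient_attention(ScalarType dtype, ROCmFABackend preferred) {
  if (preferred == ROCmFABackend::Ck) {
    return dtype_allowed(dtype, ck_mem_efficient_dtypes);
  }
  return dtype_allowed(dtype, aotriton_mem_efficient_dtypes);
}

int main() {
  std::cout << std::boolalpha
            // FP32 with CK preferred: rejected, so SDPA would fall back to math.
            << can_use_mem_efficient_attention(ScalarType::Float, ROCmFABackend::Ck) << "\n"
            // FP32 with the default (aotriton) backend: still eligible.
            << can_use_mem_efficient_attention(ScalarType::Float, ROCmFABackend::Default) << "\n";
}

With CK preferred and an FP32 input, the sketch returns false, matching the summary: mem-efficient attention is excluded and the SDPA dispatcher picks the math backend instead.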