diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 89f4907d45a..1c38cecfe17 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -84,7 +84,8 @@ bool check_prefer_cudnn_attention() { try { auto dprops = at::cuda::getCurrentDeviceProperties(); auto major = dprops->major; - return (major == 9 || major == 10) && !dprops->minor; + auto minor = dprops->minor; + return (major == 9 || major == 10) && (!minor || minor == 3); } catch ([[maybe_unused]] c10::Error const& e) { #ifdef DEBUG TORCH_WARN("check_prefer_cudnn_attention() caught exception ", e.what());