From 1ea9cde598ead20194dbb6c5cb26e74e36e6ad55 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Mon, 14 Jul 2025 02:50:36 +0000
Subject: [PATCH] [ROCm] logsumexp on ROCm needs scaling back to natural base.
 (#156903)

Fixes #156012

This is a temporary solution that makes context parallelism work before the
logsumexp behavior changes land in AOTriton. After discussion we are not going
to release AOTriton 0.10.1 to fix this, due to the following:

* Even if the interface is not changed, changing the behavior of the returned
  logsumexp tensor should still be considered an ABI break. Such changes do
  not fall into the "ABI compatible" category and should be postponed to the
  next release.
* AOTriton 0.11 is scheduled to be released before the end of July, which is
  less than five weeks away.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156903
Approved by: https://github.com/jeffdaily, https://github.com/XilunWu
---
 .../tensor/experimental/_attention.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py
index 73b53f05142..b3a5768f6fc 100644
--- a/torch/distributed/tensor/experimental/_attention.py
+++ b/torch/distributed/tensor/experimental/_attention.py
@@ -43,6 +43,16 @@ class _RotateMethod(Enum):
 aten = torch.ops.aten
 logger = logging.getLogger(__name__)
 
+_is_hip: bool = hasattr(torch.version, "hip") and torch.version.hip is not None
+if _is_hip:
+    gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
+    _is_ck_supported = False
+    for arch in ["gfx942", "gfx950"]:
+        if arch in gcn_arch_name:
+            _is_ck_supported = True
+    _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library
+    _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"]
+
 
 class _DispatchMode(Enum):
     MONKEY_PATCH = auto()
@@ -446,6 +456,14 @@ def _templated_ring_attention(
             is_causal=is_causal_behavior.value,
             **kwargs,
         )
+        if _is_hip:  # See: https://github.com/pytorch/pytorch/issues/156012
+            need_scaling = True
+            # Note: it is possible that CK is selected but not compiled into the binary.
+            if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
+                # Unsure about CK's behavior; keep logsumexp untouched.
+                need_scaling = False
+            if need_scaling:
+                logsumexp *= 0.6931471805599453  # ln(2): scale logsumexp back to natural base
         sdpa_merger.step(out, logsumexp, partial)
 
     return *sdpa_merger.results(), *rest
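
Note on the ln(2) constant (not part of the patch): the ln(2) factor and issue
#156012 imply that the affected ROCm flash-attention path reports logsumexp in
base 2, so the hotfix rescales it by 0.6931471805599453 == ln(2) to recover the
natural-base value that sdpa_merger.step() consumes. A minimal sketch of that
conversion, assuming a backend that returns a base-2 logsumexp; the tensor names
and shapes below are illustrative only, not PyTorch API:

    import math

    import torch

    scores = torch.randn(4, 8)

    # What a base-2 kernel would report: log2(sum(exp(scores))).
    # (Assumed behavior, per the issue linked above.)
    lse_base2 = torch.logsumexp(scores, dim=-1) / math.log(2)

    # Multiplying by ln(2) recovers the natural-base logsumexp.
    lse_natural = lse_base2 * 0.6931471805599453  # 0.6931471805599453 == math.log(2)

    torch.testing.assert_close(lse_natural, torch.logsumexp(scores, dim=-1))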