From 3c98eef883dbf497577043deedeeb22c30b4ff72 Mon Sep 17 00:00:00 2001
From: Jack Taylor
Date: Tue, 30 Dec 2025 09:48:53 +0000
Subject: [PATCH] [ROCm] enable decompose k tests for functional coverage (#169948)

Fixes https://github.com/pytorch/pytorch/issues/168617
Fixes https://github.com/pytorch/pytorch/issues/168615
Fixes https://github.com/pytorch/pytorch/issues/168614
Fixes https://github.com/pytorch/pytorch/issues/168613
Fixes https://github.com/pytorch/pytorch/issues/168599
Fixes https://github.com/pytorch/pytorch/issues/168600
Fixes https://github.com/pytorch/pytorch/issues/168601
Fixes https://github.com/pytorch/pytorch/issues/168602
Fixes https://github.com/pytorch/pytorch/issues/168603
Fixes https://github.com/pytorch/pytorch/issues/168604
Fixes https://github.com/pytorch/pytorch/issues/168605
Fixes https://github.com/pytorch/pytorch/issues/168606
Fixes https://github.com/pytorch/pytorch/issues/168607

Enables testing for decompose_k mode on ROCm. This is still disabled by default
pending perf testing, but we can get functional coverage by adding an inductor
config for decompose_k enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169948
Approved by: https://github.com/jansel, https://github.com/eellison, https://github.com/PaulZhang12
---
 test/inductor/test_lookup_table.py     |  35 +-
 test/inductor/test_max_autotune.py     | 298 +++++++++---------
 torch/_inductor/config.py              |   5 +-
 .../template_heuristics/decompose_k.py |   8 +-
 torch/_inductor/utils.py               |   3 +-
 5 files changed, 182 insertions(+), 167 deletions(-)

diff --git a/test/inductor/test_lookup_table.py b/test/inductor/test_lookup_table.py
index 250a8222678..3e69afe3599 100644
--- a/test/inductor/test_lookup_table.py
+++ b/test/inductor/test_lookup_table.py
@@ -23,12 +23,17 @@ from torch._inductor.virtualized import V
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
-    TEST_WITH_ROCM,
 )
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON, HAS_GPU
 from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device
 
 
+# Conditional patch for decompose_k tests - override to 10 on ROCm, no-op elsewhere
+_DECOMPOSE_K_PATCH_ROCM = (
+    {"triton.num_decompose_k_splits": 10} if torch.version.hip else {}
+)
+
+
 class MockTensorNode:
     """Mock input node that wraps a real tensor for testing"""
 
@@ -833,7 +838,6 @@ class BaseE2ELookupTableTest(BaseLookupTableTest):
     ]
 
 
-@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support lookup table")
 @unittest.skipIf(not HAS_CUDA_AND_TRITON, "CUDA not available")
 @instantiate_parametrized_tests
 class TestLookupTableE2E(BaseE2ELookupTableTest):
@@ -928,21 +932,26 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
             operation, tensors, {"triton.enable_persistent_tma_matmul": True}
         )
 
+    # Enable decompose_k for this test (disabled by default on ROCm)
     @fresh_cache()
     def test_decompose_k_lookup_table_entry(self):
         """Test decompose_k template entry"""
-        tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32)
-        config = self.create_basic_config(
-            torch._inductor.kernel.mm.decompose_k_subgraph_template.uid
-        )
-
-        self.setup_lookup_table("mm", tensors, [config])
-        add_preprocessing_fn(
-            partial(
-                verify_choice_names, pattern="decompose_k|bmm_dtype", expected_count=1
+        with inductor_config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32)
+            config = self.create_basic_config(
+                torch._inductor.kernel.mm.decompose_k_subgraph_template.uid
             )
-        )
-        self.run_model("mm", tensors)
+
+            self.setup_lookup_table("mm", tensors, [config])
+            add_preprocessing_fn(
+                partial(
+                    verify_choice_names,
+                    pattern="decompose_k|bmm_dtype",
+                    expected_count=1,
+                )
+            )
+
+            self.run_model("mm", tensors)
 
     @fresh_cache()
     def test_bias_addmm_lookup_table_entry(self):
diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index bd4d44301de..f9ffd391ae7 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -105,6 +105,11 @@ else:
 if HAS_CUDA_AND_TRITON:
     torch.cuda.memory._set_allocator_settings("expandable_segments:False")
 
+# Conditional patch for decompose_k tests - override to 10 on ROCm, no-op elsewhere
+_DECOMPOSE_K_PATCH_ROCM = (
+    {"triton.num_decompose_k_splits": 10} if torch.version.hip else {}
+)
+
 
 def benchmark_choice(choice, args, out, expected_out, timings):
     result = choice.benchmark(*args, out=out)
@@ -1214,101 +1219,100 @@ class TestMaxAutotune(TestCase):
         shape_padding=False,
     )
     def test_max_autotune_decompose_k(self, sizes, dtype, dynamic):
-        fp16_red_setting = (
-            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
-        )
-        bf16_red_setting = (
-            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
-        )
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            fp16_red_setting = (
+                torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+            )
+            bf16_red_setting = (
+                torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+            )
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
 
-        M, N, K = sizes
+            M, N, K = sizes
 
-        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
-        b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
+            a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
+            b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
 
-        possible_splits = range(2, min(K // M, K // N) + 1)
+            possible_splits = range(2, min(K // M, K // N) + 1)
 
-        divisors = {split for split in possible_splits if K % split == 0}
+            divisors = {split for split in possible_splits if K % split == 0}
 
-        def check_divisors(code):
-            for kernel in code:
-                if "decompose_k" in kernel:
-                    divisor_found = False
-                    for divisor in divisors:
-                        if f"{divisor}_split" in kernel:
-                            divisor_found = True
-                            break
+            def check_divisors(code):
+                for kernel in code:
+                    if "decompose_k" in kernel:
+                        divisor_found = False
+                        for divisor in divisors:
+                            if f"{divisor}_split" in kernel:
+                                divisor_found = True
+                                break
 
-                    self.assertTrue(
-                        divisor_found,
-                        f"Could not find a split in {divisors} in {kernel}",
-                    )
+                        self.assertTrue(
+                            divisor_found,
+                            f"Could not find a split in {divisors} in {kernel}",
+                        )
 
-        compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
-        # We assume with the large k dim relative to m, n, decompose_k will be most performant
-        out, code = run_and_get_code(compiled_func, a, b)
+            compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
+            # We assume with the large k dim relative to m, n, decompose_k will be most performant
+            out, code = run_and_get_code(compiled_func, a, b)
 
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)
 
-        # Test adding epilogue also equivalent to eager
-        compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
-        out, code = run_and_get_code(compiled_func, a, b)
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_mm_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(
-                compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
+            # Test adding epilogue also equivalent to eager
+            compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
+            out, code = run_and_get_code(compiled_func, a, b)
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_mm_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(
+                    compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
+                )
+
+            # Test adding reinterpret view before subgraph
+            a = a.transpose(0, 1)
+            compiled_func = torch.compile(
+                lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
+            )
+            out, code = run_and_get_code(compiled_func, a, b)
+
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_.*_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(
+                    compiled_func(a, b),
+                    (a.transpose(0, 1) @ b).relu(),
+                    atol=1e-2,
+                    rtol=1e-2,
+                )
+
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+                fp16_red_setting
+            )
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+                bf16_red_setting
             )
 
-        # Test adding reinterpret view before subgraph
-        a = a.transpose(0, 1)
-        compiled_func = torch.compile(
-            lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
-        )
-        out, code = run_and_get_code(compiled_func, a, b)
-
-        # DecomposeK is not enabled for AMD yet
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_.*_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(
-                compiled_func(a, b),
-                (a.transpose(0, 1) @ b).relu(),
-                atol=1e-2,
-                rtol=1e-2,
-            )
-
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            fp16_red_setting
-        )
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
-            bf16_red_setting
-        )
-
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1356,7 +1360,6 @@ class TestMaxAutotune(TestCase):
             rtol=1e-2,
         )
 
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1369,44 +1372,45 @@ class TestMaxAutotune(TestCase):
         max_autotune_gemm_backends="TRITON",
     )
     def test_max_autotune_decompose_k_dynamic_input_bwd(self):
-        def f(a, b):
-            # 256 * s0
-            a_in = torch.cat([a for _ in range(256)], dim=0)
-            return (a_in @ b).relu().sum()
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
 
-        a = torch.randn(
-            8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
-        )
-        b = torch.randn(
-            64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
-        )
+            def f(a, b):
+                # 256 * s0
+                a_in = torch.cat([a for _ in range(256)], dim=0)
+                return (a_in @ b).relu().sum()
 
-        torch._dynamo.reset()
-        torch._dynamo.maybe_mark_dynamic(a, 0)
-        compiled_func = torch.compile(f)
-        res = compiled_func(a, b)
-        res.backward()
-
-        with mock.patch(
-            "torch._inductor.kernel.mm.use_decompose_k_choice"
-        ) as decomp_mock:
-            decomp_mock.side_effect = (
-                lambda *args, **kwargs: kwargs.get("threshold_multiple", 1) == 1
+            a = torch.randn(
+                8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
+            )
+            b = torch.randn(
+                64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
             )
 
-            out, code = run_and_get_code(compiled_func, a, b)
-            out.backward()
+            torch._dynamo.reset()
+            torch._dynamo.maybe_mark_dynamic(a, 0)
+            compiled_func = torch.compile(f)
+            res = compiled_func(a, b)
+            res.backward()
 
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_0.run"
-            ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
-                r"256\*s[0-9]+"
-            ).check_regex("s[0-9]+ = 8").run(
-                # code[1] in this case given backwards
-                code[1]
-            )
+            with mock.patch(
+                "torch._inductor.kernel.mm.use_decompose_k_choice"
+            ) as decomp_mock:
+                decomp_mock.side_effect = (
+                    lambda *args, **kwargs: kwargs.get("threshold_multiple", 1) == 1
+                )
+
+                out, code = run_and_get_code(compiled_func, a, b)
+                out.backward()
+
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_0.run"
+                ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
+                    r"256\*s[0-9]+"
+                ).check_regex("s[0-9]+ = 8").run(
+                    # code[1] in this case given backwards
+                    code[1]
+                )
 
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1419,38 +1423,42 @@ class TestMaxAutotune(TestCase):
         max_autotune_gemm_backends="TRITON",
     )
     def test_max_autotune_decompose_k_output_stride(self):
-        def f(a, b):
-            a = a.transpose(0, 1)
-            return a @ b
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
 
-        a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16)
-        b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16)
+            def f(a, b):
+                a = a.transpose(0, 1)
+                return a @ b
 
-        b = b[:, :1096]
+            a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16)
+            b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16)
 
-        # Force only decomposeK choice
-        with (
-            override_template_heuristics(
-                device_type=GPU_TYPE,
-                template_op_pairs=[(torch._inductor.kernel.mm.mm_template.name, "mm")],
-            ),
-            mock.patch(
-                "torch._inductor.kernel.mm.use_decompose_k_choice"
-            ) as decompose_mock,
-        ):
-            decompose_mock.return_value = True
-            compiled_f = torch.compile(f)
-            out, code = run_and_get_code(compiled_f, a, b)
+            b = b[:, :1096]
 
-            # Output stride equal to original gm output stride
-            # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
-            self.assertEqual(out.stride(), (1096, 1))
+            # Force only decomposeK choice
+            with (
+                override_template_heuristics(
+                    device_type=GPU_TYPE,
+                    template_op_pairs=[
+                        (torch._inductor.kernel.mm.mm_template.name, "mm")
+                    ],
+                ),
+                mock.patch(
+                    "torch._inductor.kernel.mm.use_decompose_k_choice"
+                ) as decompose_mock,
+            ):
+                decompose_mock.return_value = True
+                compiled_f = torch.compile(f)
+                out, code = run_and_get_code(compiled_f, a, b)
 
-            FileCheck().check_not("extern_kernels.bmm_dtype").check(
-                "decompose_k"
-            ).check(
-                f" empty_strided_{GPU_TYPE}((256, 1096), (1096, 1), torch.bfloat16)"
-            ).run(code[0])
+                # Output stride equal to original gm output stride
+                # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
+                self.assertEqual(out.stride(), (1096, 1))
+
+                FileCheck().check_not("extern_kernels.bmm_dtype").check(
+                    "decompose_k"
+                ).check(
+                    f" empty_strided_{GPU_TYPE}((256, 1096), (1096, 1), torch.bfloat16)"
+                ).run(code[0])
 
     @unittest.skipIf(not torch.version.hip, "ROCM only")
     @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32))
@@ -2056,7 +2064,7 @@ class TestMaxAutotune(TestCase):
         self.assertEqual(misses(), 4)
 
     @fresh_cache()
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
+    @skipIfXpu
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index ee4124df8e9..5f1f05924e9 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1718,8 +1718,11 @@ class triton:
     disallow_failing_autotune_kernels_TESTING_ONLY = False
 
     # specify number of splits to autotune on for decompose_k. 0 disables decompose_k
+    # Disabled on ROCm by default pending performance validation.
     num_decompose_k_splits = int(
-        os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10")
+        os.environ.get(
+            "TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "0" if torch.version.hip else "10"
+        )
     )
 
     # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables
diff --git a/torch/_inductor/template_heuristics/decompose_k.py b/torch/_inductor/template_heuristics/decompose_k.py
index 18bcbe6aa45..8cabcbfea08 100644
--- a/torch/_inductor/template_heuristics/decompose_k.py
+++ b/torch/_inductor/template_heuristics/decompose_k.py
@@ -4,8 +4,6 @@ from typing import Any, TYPE_CHECKING
 
 import sympy
 
-import torch
-
 from ..ir import get_free_symbols
 from ..kernel.mm import decompose_k_subgraph_template
 from ..kernel_inputs import KernelInputs, MMKernelInputs
@@ -30,15 +28,13 @@ class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
     "xpu",
     op_name="mm",
 )
-# on CUDA, we don't support hip for decompose_k yet
+# Register on CUDA (both NVIDIA and ROCm/HIP)
+# Runtime enablement is controlled by config.triton.num_decompose_k_splits (0 disables)
 @register_template_heuristic(
     decompose_k_subgraph_template.uid,
     "cuda",
-    register=torch.version.hip is None,
     op_name="mm",
 )
-# TODO(coconutruben): enable decompose k on AMD by removing the register bool
-# and benchmarking it for performance and stability
 # TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
 # by either adding specific register_template_heuristic tags, or setting the
 # device to None (enabled on all devices)
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 40f0dad91f5..f4d501b00cd 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -2169,8 +2169,7 @@ def use_decompose_k_choice(
     decompose_k_threshold = config.triton.decompose_k_threshold * threshold_multiple
 
     return (
-        not torch.version.hip
-        and V.graph.sizevars.statically_known_true(
+        V.graph.sizevars.statically_known_true(
             sympy.And(
                 sympy.Ge(k, decompose_k_threshold * m),
                 sympy.Ge(k, decompose_k_threshold * n),
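
For reference, a minimal usage sketch (illustrative only, not part of the patch): with this change, decompose_k autotuning is gated at runtime by the inductor config triton.num_decompose_k_splits (environment variable TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS), which now defaults to 0 on ROCm and 10 elsewhere. A ROCm build can opt in the same way the tests above do; the shapes and the max-autotune mode below are assumptions chosen only to place the matmul in the large-K regime where decompose_k is considered.

# Illustrative sketch - mirrors the _DECOMPOSE_K_PATCH_ROCM pattern used by the tests.
import torch
import torch._inductor.config as inductor_config

# num_decompose_k_splits = 0 disables decompose_k; re-enable it explicitly on ROCm.
with inductor_config.patch({"triton.num_decompose_k_splits": 10}):
    a = torch.randn(32, 32768, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(32768, 32, device="cuda", dtype=torch.bfloat16)
    compiled_mm = torch.compile(lambda x, y: x @ y, mode="max-autotune")
    out = compiled_mm(a, b)  # decompose_k may now appear among the autotuned choices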