Jack Taylor
2025-12-30 09:48:53 +00:00
committed by PyTorch MergeBot
parent 9ade6aad80
commit 3c98eef883
5 changed files with 182 additions and 167 deletions

View File

@@ -23,12 +23,17 @@ from torch._inductor.virtualized import V
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     TEST_WITH_ROCM,
 )
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON, HAS_GPU
 from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device
 
+# Conditional patch for decompose_k tests - override to 10 on ROCm, no-op elsewhere
+_DECOMPOSE_K_PATCH_ROCM = (
+    {"triton.num_decompose_k_splits": 10} if torch.version.hip else {}
+)
+
 
 class MockTensorNode:
     """Mock input node that wraps a real tensor for testing"""
@@ -833,7 +838,6 @@ class BaseE2ELookupTableTest(BaseLookupTableTest):
     ]
 
-@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support lookup table")
 @unittest.skipIf(not HAS_CUDA_AND_TRITON, "CUDA not available")
 @instantiate_parametrized_tests
 class TestLookupTableE2E(BaseE2ELookupTableTest):
@@ -928,21 +932,26 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
             operation, tensors, {"triton.enable_persistent_tma_matmul": True}
         )
 
+    # Enable decompose_k for this test (disabled by default on ROCm)
     @fresh_cache()
     def test_decompose_k_lookup_table_entry(self):
         """Test decompose_k template entry"""
-        tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32)
-        config = self.create_basic_config(
-            torch._inductor.kernel.mm.decompose_k_subgraph_template.uid
-        )
-        self.setup_lookup_table("mm", tensors, [config])
-        add_preprocessing_fn(
-            partial(
-                verify_choice_names, pattern="decompose_k|bmm_dtype", expected_count=1
-            )
-        )
-        self.run_model("mm", tensors)
+        with inductor_config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32)
+            config = self.create_basic_config(
+                torch._inductor.kernel.mm.decompose_k_subgraph_template.uid
+            )
+            self.setup_lookup_table("mm", tensors, [config])
+            add_preprocessing_fn(
+                partial(
+                    verify_choice_names,
+                    pattern="decompose_k|bmm_dtype",
+                    expected_count=1,
+                )
+            )
+            self.run_model("mm", tensors)
 
     @fresh_cache()
     def test_bias_addmm_lookup_table_entry(self):

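A note on the _DECOMPOSE_K_PATCH_ROCM pattern introduced above: inductor's config.patch takes a dict of dotted option names and restores the previous values on exit, so patching an empty dict is a no-op context manager. A minimal self-contained sketch of the idea (the pass body stands in for arbitrary test code):

import torch
from torch._inductor import config as inductor_config

# Override to 10 splits on ROCm builds; empty dict (no-op) elsewhere.
_DECOMPOSE_K_PATCH_ROCM = (
    {"triton.num_decompose_k_splits": 10} if torch.version.hip else {}
)

with inductor_config.patch(_DECOMPOSE_K_PATCH_ROCM):
    # On ROCm this temporarily re-enables decompose_k autotuning splits
    # (the new default there is 0, i.e. disabled); on CUDA nothing changes.
    pass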
View File

@@ -105,6 +105,11 @@ else:
 if HAS_CUDA_AND_TRITON:
     torch.cuda.memory._set_allocator_settings("expandable_segments:False")
 
+# Conditional patch for decompose_k tests - override to 10 on ROCm, no-op elsewhere
+_DECOMPOSE_K_PATCH_ROCM = (
+    {"triton.num_decompose_k_splits": 10} if torch.version.hip else {}
+)
+
 
 def benchmark_choice(choice, args, out, expected_out, timings):
     result = choice.benchmark(*args, out=out)
@@ -1214,101 +1219,100 @@ class TestMaxAutotune(TestCase):
         shape_padding=False,
     )
     def test_max_autotune_decompose_k(self, sizes, dtype, dynamic):
-        fp16_red_setting = (
-            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
-        )
-        bf16_red_setting = (
-            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
-        )
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            fp16_red_setting = (
+                torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+            )
+            bf16_red_setting = (
+                torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+            )
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
 
-        M, N, K = sizes
+            M, N, K = sizes
 
-        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
-        b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
+            a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
+            b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
 
-        possible_splits = range(2, min(K // M, K // N) + 1)
+            possible_splits = range(2, min(K // M, K // N) + 1)
 
-        divisors = {split for split in possible_splits if K % split == 0}
+            divisors = {split for split in possible_splits if K % split == 0}
 
-        def check_divisors(code):
-            for kernel in code:
-                if "decompose_k" in kernel:
-                    divisor_found = False
-                    for divisor in divisors:
-                        if f"{divisor}_split" in kernel:
-                            divisor_found = True
-                            break
+            def check_divisors(code):
+                for kernel in code:
+                    if "decompose_k" in kernel:
+                        divisor_found = False
+                        for divisor in divisors:
+                            if f"{divisor}_split" in kernel:
+                                divisor_found = True
+                                break
 
-                    self.assertTrue(
-                        divisor_found,
-                        f"Could not find a split in {divisors} in {kernel}",
-                    )
+                        self.assertTrue(
+                            divisor_found,
+                            f"Could not find a split in {divisors} in {kernel}",
+                        )
 
-        compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
-        # We assume with the large k dim relative to m, n, decompose_k will be most performant
-        out, code = run_and_get_code(compiled_func, a, b)
+            compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
+            # We assume with the large k dim relative to m, n, decompose_k will be most performant
+            out, code = run_and_get_code(compiled_func, a, b)
 
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)
 
-        # Test adding epilogue also equivalent to eager
-        compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
-        out, code = run_and_get_code(compiled_func, a, b)
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_mm_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(
-                compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
-            )
+            # Test adding epilogue also equivalent to eager
+            compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
+            out, code = run_and_get_code(compiled_func, a, b)
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_mm_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(
+                    compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
+                )
 
-        # Test adding reinterpret view before subgraph
-        a = a.transpose(0, 1)
-        compiled_func = torch.compile(
-            lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
-        )
-        out, code = run_and_get_code(compiled_func, a, b)
-        # DecomposeK is not enabled for AMD yet
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_.*_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(
-                compiled_func(a, b),
-                (a.transpose(0, 1) @ b).relu(),
-                atol=1e-2,
-                rtol=1e-2,
-            )
+            # Test adding reinterpret view before subgraph
+            a = a.transpose(0, 1)
+            compiled_func = torch.compile(
+                lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
+            )
+            out, code = run_and_get_code(compiled_func, a, b)
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_.*_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(
+                    compiled_func(a, b),
+                    (a.transpose(0, 1) @ b).relu(),
+                    atol=1e-2,
+                    rtol=1e-2,
+                )
 
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            fp16_red_setting
-        )
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
-            bf16_red_setting
-        )
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+                fp16_red_setting
+            )
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+                bf16_red_setting
+            )
 
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1356,7 +1360,6 @@ class TestMaxAutotune(TestCase):
             rtol=1e-2,
         )
 
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1369,44 +1372,45 @@ class TestMaxAutotune(TestCase):
         max_autotune_gemm_backends="TRITON",
     )
     def test_max_autotune_decompose_k_dynamic_input_bwd(self):
-        def f(a, b):
-            # 256 * s0
-            a_in = torch.cat([a for _ in range(256)], dim=0)
-            return (a_in @ b).relu().sum()
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
 
-        a = torch.randn(
-            8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
-        )
-        b = torch.randn(
-            64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
-        )
+            def f(a, b):
+                # 256 * s0
+                a_in = torch.cat([a for _ in range(256)], dim=0)
+                return (a_in @ b).relu().sum()
 
-        torch._dynamo.reset()
-        torch._dynamo.maybe_mark_dynamic(a, 0)
-        compiled_func = torch.compile(f)
-        res = compiled_func(a, b)
-        res.backward()
+            a = torch.randn(
+                8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
+            )
+            b = torch.randn(
+                64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
+            )
 
-        with mock.patch(
-            "torch._inductor.kernel.mm.use_decompose_k_choice"
-        ) as decomp_mock:
-            decomp_mock.side_effect = (
-                lambda *args, **kwargs: kwargs.get("threshold_multiple", 1) == 1
-            )
-            out, code = run_and_get_code(compiled_func, a, b)
-            out.backward()
+            torch._dynamo.reset()
+            torch._dynamo.maybe_mark_dynamic(a, 0)
+            compiled_func = torch.compile(f)
+            res = compiled_func(a, b)
+            res.backward()
 
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_0.run"
-            ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
-                r"256\*s[0-9]+"
-            ).check_regex("s[0-9]+ = 8").run(
-                # code[1] in this case given backwards
-                code[1]
-            )
+            with mock.patch(
+                "torch._inductor.kernel.mm.use_decompose_k_choice"
+            ) as decomp_mock:
+                decomp_mock.side_effect = (
+                    lambda *args, **kwargs: kwargs.get("threshold_multiple", 1) == 1
+                )
+                out, code = run_and_get_code(compiled_func, a, b)
+                out.backward()
+
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_0.run"
+                ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
+                    r"256\*s[0-9]+"
+                ).check_regex("s[0-9]+ = 8").run(
+                    # code[1] in this case given backwards
+                    code[1]
+                )
 
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1419,38 +1423,42 @@ class TestMaxAutotune(TestCase):
         max_autotune_gemm_backends="TRITON",
     )
     def test_max_autotune_decompose_k_output_stride(self):
-        def f(a, b):
-            a = a.transpose(0, 1)
-            return a @ b
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
 
-        a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16)
-        b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16)
+            def f(a, b):
+                a = a.transpose(0, 1)
+                return a @ b
 
-        b = b[:, :1096]
+            a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16)
+            b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16)
 
-        # Force only decomposeK choice
-        with (
-            override_template_heuristics(
-                device_type=GPU_TYPE,
-                template_op_pairs=[(torch._inductor.kernel.mm.mm_template.name, "mm")],
-            ),
-            mock.patch(
-                "torch._inductor.kernel.mm.use_decompose_k_choice"
-            ) as decompose_mock,
-        ):
-            decompose_mock.return_value = True
-            compiled_f = torch.compile(f)
-            out, code = run_and_get_code(compiled_f, a, b)
+            b = b[:, :1096]
 
-        # Output stride equal to original gm output stride
-        # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
-        self.assertEqual(out.stride(), (1096, 1))
+            # Force only decomposeK choice
+            with (
+                override_template_heuristics(
+                    device_type=GPU_TYPE,
+                    template_op_pairs=[
+                        (torch._inductor.kernel.mm.mm_template.name, "mm")
+                    ],
+                ),
+                mock.patch(
+                    "torch._inductor.kernel.mm.use_decompose_k_choice"
+                ) as decompose_mock,
+            ):
+                decompose_mock.return_value = True
+                compiled_f = torch.compile(f)
+                out, code = run_and_get_code(compiled_f, a, b)
 
-        FileCheck().check_not("extern_kernels.bmm_dtype").check(
-            "decompose_k"
-        ).check(
-            f" empty_strided_{GPU_TYPE}((256, 1096), (1096, 1), torch.bfloat16)"
-        ).run(code[0])
+            # Output stride equal to original gm output stride
+            # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
+            self.assertEqual(out.stride(), (1096, 1))
+
+            FileCheck().check_not("extern_kernels.bmm_dtype").check(
+                "decompose_k"
+            ).check(
+                f" empty_strided_{GPU_TYPE}((256, 1096), (1096, 1), torch.bfloat16)"
+            ).run(code[0])
 
     @unittest.skipIf(not torch.version.hip, "ROCM only")
     @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32))
@@ -2056,7 +2064,7 @@ class TestMaxAutotune(TestCase):
         self.assertEqual(misses(), 4)
 
     @fresh_cache()
-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
+    @skipIfXpu
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )

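The assertion branches above drop torch.version.hip from the skip condition: with decompose_k now enabled on ROCm, its kernels are expected in the generated code there as well. For readers unfamiliar with the checker, torch.testing's FileCheck asserts that patterns appear in order in a string; a minimal sketch against a hypothetical stand-in for the code[0] output of run_and_get_code:

from torch.testing import FileCheck

# Hypothetical generated-code snippet, illustrative only.
generated = (
    "extern_kernels.bmm_dtype(buf0, ...)\n"
    "triton_mm_fused_0.run(grid, ...)\n"
    "decompose_k_4_split kernel\n"
)

# Passes: each pattern is found after the previous match position.
FileCheck().check("extern_kernels.bmm_dtype").check_regex(
    "triton_.*_fused_0.run"
).check("decompose_k").run(generated)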
View File

@@ -1718,8 +1718,11 @@ class triton:
     disallow_failing_autotune_kernels_TESTING_ONLY = False
 
     # specify number of splits to autotune on for decompose_k. 0 disables decompose_k
+    # Disabled on ROCm by default pending performance validation.
     num_decompose_k_splits = int(
-        os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10")
+        os.environ.get(
+            "TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "0" if torch.version.hip else "10"
+        )
     )
 
     # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables

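With this change, ROCm builds default TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS to "0" (decompose_k disabled) while other builds keep "10". Two ways to opt back in on ROCm, sketched below (train.py is a placeholder script name; the env var must be set before the process imports torch, since the default is read at import time):

# In the shell:
#   TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS=10 python train.py

# Or scoped at runtime via the config patcher:
from torch._inductor import config as inductor_config

with inductor_config.patch({"triton.num_decompose_k_splits": 10}):
    ...  # torch.compile'd matmuls here may include decompose_k candidates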
View File

@@ -4,8 +4,6 @@ from typing import Any, TYPE_CHECKING
 
 import sympy
 
 import torch
 
-from ..ir import get_free_symbols
-
 from ..kernel.mm import decompose_k_subgraph_template
 from ..kernel_inputs import KernelInputs, MMKernelInputs
@@ -30,15 +28,13 @@ class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
     "xpu",
     op_name="mm",
 )
 
-# on CUDA, we don't support hip for decompose_k yet
+# Register on CUDA (both NVIDIA and ROCm/HIP)
+# Runtime enablement is controlled by config.triton.num_decompose_k_splits (0 disables)
 @register_template_heuristic(
     decompose_k_subgraph_template.uid,
     "cuda",
-    register=torch.version.hip is None,
     op_name="mm",
 )
-# TODO(coconutruben): enable decompose k on AMD by removing the register bool
-# and benchmarking it for performance and stability
 # TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
 # by either adding specific register_template_heuristic tags, or setting the
 # device to None (enabled on all devices)

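Design note: this moves the gate from import time (the removed register=torch.version.hip is None flag) to runtime (config.triton.num_decompose_k_splits), which a ROCm user can flip without touching the registry. A generic sketch of the import-time pattern being retired, illustrative only, with a hypothetical dict-backed registry rather than inductor's actual implementation:

_REGISTRY: dict[tuple[str, str, str | None], type] = {}

def register_template_heuristic_sketch(uid, device, *, register=True, op_name=None):
    def decorator(cls):
        if register:
            # Only gated-in classes become discoverable via the registry;
            # the class itself is defined either way.
            _REGISTRY[(uid, device, op_name)] = cls
        return cls
    return decorator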
View File

@@ -2169,8 +2169,7 @@ def use_decompose_k_choice(
     decompose_k_threshold = config.triton.decompose_k_threshold * threshold_multiple
 
     return (
-        not torch.version.hip
-        and V.graph.sizevars.statically_known_true(
+        V.graph.sizevars.statically_known_true(
             sympy.And(
                 sympy.Ge(k, decompose_k_threshold * m),
                 sympy.Ge(k, decompose_k_threshold * n),
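Concretely, the condition now asks only whether K dominates both M and N by the configured threshold, with no HIP carve-out. A plain-Python sketch for static shapes (the threshold value 32 is illustrative; the real value is config.triton.decompose_k_threshold times threshold_multiple):

def decompose_k_eligible_sketch(m: int, n: int, k: int, threshold: int = 32) -> bool:
    # Mirrors sympy.And(sympy.Ge(k, threshold * m), sympy.Ge(k, threshold * n))
    # for concrete (non-symbolic) sizes.
    return k >= threshold * m and k >= threshold * n

# The lookup-table test above uses m=n=32, k=32*32, i.e. K/M and K/N ratios of 32.
assert decompose_k_eligible_sketch(32, 32, 32 * 32)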