Mirror of https://github.com/zebrajr/pytorch.git, synced 2026-01-15 12:15:51 +00:00
[ROCm] enable decompose k tests for functional coverage (#169948)
Fixes https://github.com/pytorch/pytorch/issues/168617
Fixes https://github.com/pytorch/pytorch/issues/168615
Fixes https://github.com/pytorch/pytorch/issues/168614
Fixes https://github.com/pytorch/pytorch/issues/168613
Fixes https://github.com/pytorch/pytorch/issues/168599
Fixes https://github.com/pytorch/pytorch/issues/168600
Fixes https://github.com/pytorch/pytorch/issues/168601
Fixes https://github.com/pytorch/pytorch/issues/168602
Fixes https://github.com/pytorch/pytorch/issues/168603
Fixes https://github.com/pytorch/pytorch/issues/168604
Fixes https://github.com/pytorch/pytorch/issues/168605
Fixes https://github.com/pytorch/pytorch/issues/168606
Fixes https://github.com/pytorch/pytorch/issues/168607

Enables testing for decompose K mode on ROCm. This is still disabled by default pending perf testing, but we can get the functional coverage by adding an inductor config for decompose K enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169948
Approved by: https://github.com/jansel, https://github.com/eellison, https://github.com/PaulZhang12
Committed by: PyTorch MergeBot
Parent: 9ade6aad80
Commit: 3c98eef883
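The diff below routes ROCm enablement through a single inductor config knob, `triton.num_decompose_k_splits` (0 disables the template), instead of hard-coded `torch.version.hip` checks. A minimal sketch of how a caller might opt in on a ROCm build, assuming Triton is available; the shapes and the split count of 10 are illustrative, not mandated by the PR:

import torch
from torch._inductor import config as inductor_config

def matmul(a, b):
    return a @ b

# decompose_k is only attractive when K is much larger than M and N,
# so use a tall-K problem for illustration.
a = torch.randn(32, 32 * 256, device="cuda", dtype=torch.float16)
b = torch.randn(32 * 256, 32, device="cuda", dtype=torch.float16)

# On ROCm the default number of decompose_k splits is 0 (disabled);
# patching it to a nonzero value lets max-autotune consider the decompose_k template.
with inductor_config.patch(
    {
        "max_autotune": True,
        "triton.num_decompose_k_splits": 10,
    }
):
    compiled = torch.compile(matmul)
    out = compiled(a, b)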
@@ -23,12 +23,17 @@ from torch._inductor.virtualized import V
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    TEST_WITH_ROCM,
 )
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON, HAS_GPU
 from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device


+# Conditional patch for decompose_k tests - override to 10 on ROCm, no-op elsewhere
+_DECOMPOSE_K_PATCH_ROCM = (
+    {"triton.num_decompose_k_splits": 10} if torch.version.hip else {}
+)


 class MockTensorNode:
     """Mock input node that wraps a real tensor for testing"""
@@ -833,7 +838,6 @@ class BaseE2ELookupTableTest(BaseLookupTableTest):
        ]


-@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support lookup table")
 @unittest.skipIf(not HAS_CUDA_AND_TRITON, "CUDA not available")
 @instantiate_parametrized_tests
 class TestLookupTableE2E(BaseE2ELookupTableTest):
@@ -928,21 +932,26 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
             operation, tensors, {"triton.enable_persistent_tma_matmul": True}
         )

+    # Enable decompose_k for this test (disabled by default on ROCm)
     @fresh_cache()
     def test_decompose_k_lookup_table_entry(self):
         """Test decompose_k template entry"""
-        tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32)
-        config = self.create_basic_config(
-            torch._inductor.kernel.mm.decompose_k_subgraph_template.uid
-        )
-
-        self.setup_lookup_table("mm", tensors, [config])
-        add_preprocessing_fn(
-            partial(
-                verify_choice_names, pattern="decompose_k|bmm_dtype", expected_count=1
-            )
-        )
-        self.run_model("mm", tensors)
+        with inductor_config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32)
+            config = self.create_basic_config(
+                torch._inductor.kernel.mm.decompose_k_subgraph_template.uid
+            )
+
+            self.setup_lookup_table("mm", tensors, [config])
+            add_preprocessing_fn(
+                partial(
+                    verify_choice_names,
+                    pattern="decompose_k|bmm_dtype",
+                    expected_count=1,
+                )
+            )
+
+            self.run_model("mm", tensors)

     @fresh_cache()
     def test_bias_addmm_lookup_table_entry(self):
@@ -105,6 +105,11 @@ else:
 if HAS_CUDA_AND_TRITON:
     torch.cuda.memory._set_allocator_settings("expandable_segments:False")

+# Conditional patch for decompose_k tests - override to 10 on ROCm, no-op elsewhere
+_DECOMPOSE_K_PATCH_ROCM = (
+    {"triton.num_decompose_k_splits": 10} if torch.version.hip else {}
+)
+

 def benchmark_choice(choice, args, out, expected_out, timings):
     result = choice.benchmark(*args, out=out)
@@ -1214,101 +1219,100 @@ class TestMaxAutotune(TestCase):
         shape_padding=False,
     )
     def test_max_autotune_decompose_k(self, sizes, dtype, dynamic):
-        fp16_red_setting = (
-            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
-        )
-        bf16_red_setting = (
-            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
-        )
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            fp16_red_setting = (
+                torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+            )
+            bf16_red_setting = (
+                torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+            )
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

-        M, N, K = sizes
+            M, N, K = sizes

-        a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
-        b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)
+            a = torch.randn(M, K, dtype=dtype, device=GPU_TYPE, requires_grad=True)
+            b = torch.randn(K, N, dtype=dtype, device=GPU_TYPE, requires_grad=True)

-        possible_splits = range(2, min(K // M, K // N) + 1)
+            possible_splits = range(2, min(K // M, K // N) + 1)

-        divisors = {split for split in possible_splits if K % split == 0}
+            divisors = {split for split in possible_splits if K % split == 0}

-        def check_divisors(code):
-            for kernel in code:
-                if "decompose_k" in kernel:
-                    divisor_found = False
-                    for divisor in divisors:
-                        if f"{divisor}_split" in kernel:
-                            divisor_found = True
-                            break
+            def check_divisors(code):
+                for kernel in code:
+                    if "decompose_k" in kernel:
+                        divisor_found = False
+                        for divisor in divisors:
+                            if f"{divisor}_split" in kernel:
+                                divisor_found = True
+                                break

-                    self.assertTrue(
-                        divisor_found,
-                        f"Could not find a split in {divisors} in {kernel}",
-                    )
+                        self.assertTrue(
+                            divisor_found,
+                            f"Could not find a split in {divisors} in {kernel}",
+                        )

-        compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
-        # We assume with the large k dim relative to m, n, decompose_k will be most performant
-        out, code = run_and_get_code(compiled_func, a, b)
+            compiled_func = torch.compile(lambda a, b: a @ b, dynamic=dynamic)
+            # We assume with the large k dim relative to m, n, decompose_k will be most performant
+            out, code = run_and_get_code(compiled_func, a, b)

-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)

-        # Test adding epilogue also equivalent to eager
-        compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
-        out, code = run_and_get_code(compiled_func, a, b)
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_mm_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(
-                compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
-            )
+            # Test adding epilogue also equivalent to eager
+            compiled_func = torch.compile(lambda a, b: (a @ b).relu(), dynamic=dynamic)
+            out, code = run_and_get_code(compiled_func, a, b)
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_mm_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(
+                    compiled_func(a, b), (a @ b).relu(), atol=1e-2, rtol=1e-2
+                )

-        # Test adding reinterpret view before subgraph
-        a = a.transpose(0, 1)
-        compiled_func = torch.compile(
-            lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
-        )
-        out, code = run_and_get_code(compiled_func, a, b)
+            # Test adding reinterpret view before subgraph
+            a = a.transpose(0, 1)
+            compiled_func = torch.compile(
+                lambda a, b: (a.transpose(0, 1) @ b).relu(), dynamic=dynamic
+            )
+            out, code = run_and_get_code(compiled_func, a, b)

-        # DecomposeK is not enabled for AMD yet
-        if dynamic or torch.version.hip:
-            FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
-                "decompose_k"
-            ).run(code[0])
-        else:
-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_.*_0.run"
-            ).check("decompose_k").run(code[0])
-            check_divisors(code)
-            torch.testing.assert_close(
-                compiled_func(a, b),
-                (a.transpose(0, 1) @ b).relu(),
-                atol=1e-2,
-                rtol=1e-2,
-            )
+            if dynamic:
+                FileCheck().check_not("extern_kernels.bmm_dtype").check_not(
+                    "decompose_k"
+                ).run(code[0])
+            else:
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_.*_0.run"
+                ).check("decompose_k").run(code[0])
+                check_divisors(code)
+                torch.testing.assert_close(
+                    compiled_func(a, b),
+                    (a.transpose(0, 1) @ b).relu(),
+                    atol=1e-2,
+                    rtol=1e-2,
+                )

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            fp16_red_setting
-        )
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
-            bf16_red_setting
-        )
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+                fp16_red_setting
+            )
+            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+                bf16_red_setting
+            )

-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1356,7 +1360,6 @@ class TestMaxAutotune(TestCase):
            rtol=1e-2,
        )

-    @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1369,44 +1372,45 @@ class TestMaxAutotune(TestCase):
         max_autotune_gemm_backends="TRITON",
     )
     def test_max_autotune_decompose_k_dynamic_input_bwd(self):
-        def f(a, b):
-            # 256 * s0
-            a_in = torch.cat([a for _ in range(256)], dim=0)
-            return (a_in @ b).relu().sum()
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            def f(a, b):
+                # 256 * s0
+                a_in = torch.cat([a for _ in range(256)], dim=0)
+                return (a_in @ b).relu().sum()

-        a = torch.randn(
-            8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
-        )
-        b = torch.randn(
-            64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
-        )
+            a = torch.randn(
+                8, 64, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
+            )
+            b = torch.randn(
+                64, 32768, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
+            )

-        torch._dynamo.reset()
-        torch._dynamo.maybe_mark_dynamic(a, 0)
-        compiled_func = torch.compile(f)
-        res = compiled_func(a, b)
-        res.backward()
+            torch._dynamo.reset()
+            torch._dynamo.maybe_mark_dynamic(a, 0)
+            compiled_func = torch.compile(f)
+            res = compiled_func(a, b)
+            res.backward()

-        with mock.patch(
-            "torch._inductor.kernel.mm.use_decompose_k_choice"
-        ) as decomp_mock:
-            decomp_mock.side_effect = (
-                lambda *args, **kwargs: kwargs.get("threshold_multiple", 1) == 1
-            )
+            with mock.patch(
+                "torch._inductor.kernel.mm.use_decompose_k_choice"
+            ) as decomp_mock:
+                decomp_mock.side_effect = (
+                    lambda *args, **kwargs: kwargs.get("threshold_multiple", 1) == 1
+                )

-            out, code = run_and_get_code(compiled_func, a, b)
-            out.backward()
+                out, code = run_and_get_code(compiled_func, a, b)
+                out.backward()

-            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
-                "triton_.*_fused_0.run"
-            ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
-                r"256\*s[0-9]+"
-            ).check_regex("s[0-9]+ = 8").run(
-                # code[1] in this case given backwards
-                code[1]
-            )
+                FileCheck().check("extern_kernels.bmm_dtype").check_regex(
+                    "triton_.*_fused_0.run"
+                ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
+                    r"256\*s[0-9]+"
+                ).check_regex("s[0-9]+ = 8").run(
+                    # code[1] in this case given backwards
+                    code[1]
+                )

     @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
     @unittest.skipIf(
         config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
     )
@@ -1419,38 +1423,42 @@ class TestMaxAutotune(TestCase):
         max_autotune_gemm_backends="TRITON",
     )
     def test_max_autotune_decompose_k_output_stride(self):
-        def f(a, b):
-            a = a.transpose(0, 1)
-            return a @ b
+        with config.patch(_DECOMPOSE_K_PATCH_ROCM):
+            def f(a, b):
+                a = a.transpose(0, 1)
+                return a @ b

-        a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16)
-        b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16)
+            a = torch.randn((32768, 256), device=GPU_TYPE, dtype=torch.bfloat16)
+            b = torch.randn((32768, 1152), device=GPU_TYPE, dtype=torch.bfloat16)

-        b = b[:, :1096]
+            b = b[:, :1096]

-        # Force only decomposeK choice
-        with (
-            override_template_heuristics(
-                device_type=GPU_TYPE,
-                template_op_pairs=[(torch._inductor.kernel.mm.mm_template.name, "mm")],
-            ),
-            mock.patch(
-                "torch._inductor.kernel.mm.use_decompose_k_choice"
-            ) as decompose_mock,
-        ):
-            decompose_mock.return_value = True
-            compiled_f = torch.compile(f)
-            out, code = run_and_get_code(compiled_f, a, b)
+            # Force only decomposeK choice
+            with (
+                override_template_heuristics(
+                    device_type=GPU_TYPE,
+                    template_op_pairs=[
+                        (torch._inductor.kernel.mm.mm_template.name, "mm")
+                    ],
+                ),
+                mock.patch(
+                    "torch._inductor.kernel.mm.use_decompose_k_choice"
+                ) as decompose_mock,
+            ):
+                decompose_mock.return_value = True
+                compiled_f = torch.compile(f)
+                out, code = run_and_get_code(compiled_f, a, b)

-        # Output stride equal to original gm output stride
-        # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
-        self.assertEqual(out.stride(), (1096, 1))
+            # Output stride equal to original gm output stride
+            # If output stride is not correctly checked, this will be (1152, 1) which can cause nans
+            self.assertEqual(out.stride(), (1096, 1))

-        FileCheck().check_not("extern_kernels.bmm_dtype").check(
-            "decompose_k"
-        ).check(
-            f" empty_strided_{GPU_TYPE}((256, 1096), (1096, 1), torch.bfloat16)"
-        ).run(code[0])
+            FileCheck().check_not("extern_kernels.bmm_dtype").check(
+                "decompose_k"
+            ).check(
+                f" empty_strided_{GPU_TYPE}((256, 1096), (1096, 1), torch.bfloat16)"
+            ).run(code[0])

     @unittest.skipIf(not torch.version.hip, "ROCM only")
     @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32))
@@ -2056,7 +2064,7 @@ class TestMaxAutotune(TestCase):
        self.assertEqual(misses(), 4)

    @fresh_cache()
-   @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm")
+   @skipIfXpu
    @unittest.skipIf(
        config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet"
    )
@@ -1718,8 +1718,11 @@ class triton:
     disallow_failing_autotune_kernels_TESTING_ONLY = False

     # specify number of splits to autotune on for decompose_k. 0 disables decompose_k
+    # Disabled on ROCm by default pending performance validation.
     num_decompose_k_splits = int(
-        os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10")
+        os.environ.get(
+            "TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "0" if torch.version.hip else "10"
+        )
     )

     # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables
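With this default, decompose_k autotuning stays off on ROCm builds unless it is overridden. A small sketch of the two override paths; the values are illustrative, and the environment variable is only picked up if it is set before torch._inductor.config is first imported:

import os

# Option 1: environment variable, read once when the config module is imported.
os.environ["TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS"] = "10"

import torch
from torch._inductor import config as inductor_config

# Option 2: override the field directly (or use inductor_config.patch) after import.
inductor_config.triton.num_decompose_k_splits = 10

# 0 means decompose_k is disabled; a positive value caps how many
# candidate K-splits max-autotune will benchmark.
print(inductor_config.triton.num_decompose_k_splits)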
@@ -4,8 +4,6 @@ from typing import Any, TYPE_CHECKING

 import sympy

-import torch
-
 from ..ir import get_free_symbols
 from ..kernel.mm import decompose_k_subgraph_template
 from ..kernel_inputs import KernelInputs, MMKernelInputs
@@ -30,15 +28,13 @@ class EmptyDecomposeKConfigHeuristics(TemplateConfigHeuristics):
     "xpu",
     op_name="mm",
 )
-# on CUDA, we don't support hip for decompose_k yet
+# Register on CUDA (both NVIDIA and ROCm/HIP)
+# Runtime enablement is controlled by config.triton.num_decompose_k_splits (0 disables)
 @register_template_heuristic(
     decompose_k_subgraph_template.uid,
     "cuda",
-    register=torch.version.hip is None,
     op_name="mm",
 )
-# TODO(coconutruben): enable decompose k on AMD by removing the register bool
-# and benchmarking it for performance and stability
 # TODO(coconutruben): enable decompose k on other devices (xpu, cpu, mps, mtia)
 # by either adding specific register_template_heuristic tags, or setting the
 # device to None (enabled on all devices)
@@ -2169,8 +2169,7 @@ def use_decompose_k_choice(
     decompose_k_threshold = config.triton.decompose_k_threshold * threshold_multiple

     return (
-        not torch.version.hip
-        and V.graph.sizevars.statically_known_true(
+        V.graph.sizevars.statically_known_true(
             sympy.And(
                 sympy.Ge(k, decompose_k_threshold * m),
                 sympy.Ge(k, decompose_k_threshold * n),
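After this change the gate is shape- and config-driven only: decompose_k is considered when K is at least decompose_k_threshold times both M and N and num_decompose_k_splits is nonzero, with no special-casing of torch.version.hip. A standalone sketch of that predicate and of how candidate splits can be enumerated; the function names, the threshold default, and the truncation rule are illustrative, not the inductor API:

def should_try_decompose_k(
    m: int, n: int, k: int, threshold: int = 32, num_splits: int = 10
) -> bool:
    # Mirrors the gating idea: K must dwarf both M and N, and the split-count
    # config must not be 0 (which disables decompose_k entirely).
    return num_splits > 0 and k >= threshold * m and k >= threshold * n


def candidate_splits(m: int, n: int, k: int, num_splits: int = 10) -> list[int]:
    # Only splits that divide K evenly are usable; keep at most num_splits of them,
    # matching the "number of splits to autotune on" meaning of the config.
    possible = range(2, min(k // m, k // n) + 1)
    divisors = [s for s in possible if k % s == 0]
    return divisors[:num_splits]


# Example: a 32 x (32 * 32) x 32 matmul (the shape used in the tests above)
# clears the threshold, and its even divisors become the autotune candidates.
assert should_try_decompose_k(32, 32, 32 * 32)
print(candidate_splits(32, 32, 32 * 32))  # [2, 4, 8, 16, 32]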