diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 37adb0282c9..ffd7e55d233 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index d1b3b17445d..3b432305107 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py +torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index 31e78d0245d..99af489cc21 100644 --- a/setup.py +++ b/setup.py @@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +def mirror_inductor_external_kernels() -> None: + """ + Copy external kernels into Inductor so they are importable. + """ + paths = [ + ( + CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", + CWD + / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", + ), + ] + for new_path, orig_path in paths: + # Create the dirs involved in new_path if they don't exist + if not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) + + # Copy the files from the orig location to the new location + if orig_path.is_file(): + shutil.copyfile(orig_path, new_path) + continue + if orig_path.is_dir(): + if new_path.exists(): + # copytree fails if the tree exists already, so remove it. 
+ shutil.rmtree(new_path) + shutil.copytree(orig_path, new_path) + continue + raise RuntimeError( + "Check the file paths in `mirror_inductor_external_kernels()`" + ) + + # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1615,6 +1646,7 @@ def main() -> None: mirror_files_into_torchgen() if RUN_BUILD_DEPS: build_deps() + mirror_inductor_external_kernels() ( ext_modules, @@ -1649,6 +1681,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", + "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py new file mode 100644 index 00000000000..c26def3a540 --- /dev/null +++ b/test/inductor/test_cutedsl_grouped_mm.py @@ -0,0 +1,154 @@ +# Owner(s): ["module: inductor"] + + +import unittest + +import torch +from torch import Tensor +from torch._inductor import config +from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch +from torch._inductor.test_case import run_tests, TestCase as InductorTestCase +from torch._inductor.utils import ensure_cute_available +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +@unittest.skipIf( + not (ensure_cute_available() and is_datacenter_blackwell_arch()), + "CuTeDSL library or Blackwell device not available", +) +@instantiate_parametrized_tests +class TestCuTeDSLGroupedGemm(InductorTestCase): + def _get_inputs( + self, + group_size: int, + M_hint: int, + K: int, + N: int, + device: str, + dtype: torch.dtype, + alignment: int = 16, + ) -> tuple[Tensor, Tensor, Tensor]: + # --- Random, tile-aligned M sizes --- + M_sizes = ( + torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) + * alignment + ) + + M_total = torch.sum(M_sizes).item() + + # --- Construct input tensors --- + A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 + B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 + + # --- Build offsets (no leading zero, strictly increasing) --- + offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) + + return (A, B, offsets) + + @parametrize("group_size", (2, 8)) + @parametrize("M_hint", (256, 1024)) + @parametrize("K", (64, 128)) + @parametrize("N", (128, 256)) + def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): + device = "cuda" + dtype = torch.bfloat16 + + A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # Eager execution + c_eager = grouped_gemm_fn(A, B, offsets) + + # Test with Cute backend + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) + @parametrize("layout_B", ("contiguous", "broadcasted")) + def 
test_grouped_gemm_assorted_layouts( + self, + layout_A: str, + layout_B: str, + ): + device = "cuda" + dtype = torch.bfloat16 + + G, K, N = 8, 64, 128 + M_sizes = [128] * G + sum_M = sum(M_sizes) + offsets = torch.tensor( + [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device + ) + + A_base = torch.randn(sum_M, K, device=device, dtype=dtype) + A = A_base + + if layout_A == "offset": + # allocate bigger buffer than needed, use nonzero storage offset + storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) + offset = 128 # skip first 128 elements + A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) + elif layout_A == "padded": + # simulate row pitch > K (row_stride = K + pad) + row_pitch = K + 8 + storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) + A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) + elif layout_A == "view": + A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) + A = A_storage.view(sum_M, K) + assert A._base is not None + assert A.shape == (sum_M, K) + + B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 + + if layout_B == "broadcasted": + # Broadcast B across groups (zero stride along G) + B = B[0].expand(G, K, N) + assert B.stride(0) == 0 + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # --- eager --- + c_eager = grouped_gemm_fn(A, B, offsets) + + # --- compiled (CUTE backend) --- + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index bd0ff91616b..094850eced4 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -550,6 +550,10 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] +cutedsl_enable_autotuning: bool = ( + os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" +) + # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index b95073e769f..eb22b95af2a 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence +from functools import partial +from pathlib import Path from typing import Any import torch @@ -12,6 +14,7 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols from .. 
import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox +from ..utils import load_template log = logging.getLogger(__name__) @@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True + + +_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 881c14fd43d..c81ec607661 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,13 @@ # mypy: allow-untyped-defs import logging -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters +from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl +from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -22,11 +24,13 @@ from ..utils import ( get_num_sms, has_free_symbols, use_aten_gemm_kernels, + use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, + load_kernel_template, persistent_grouped_mm_grid, ) @@ -513,6 +517,11 @@ triton_scaled_grouped_mm_template = TritonTemplate( source=triton_grouped_mm_source, ) +cutedsl_grouped_mm_template = CuteDSLTemplate( + name="grouped_gemm_cutedsl", + source=load_kernel_template("cutedsl_mm_grouped"), +) + def grouped_mm_args( mat1: TensorBox, @@ -714,43 +723,44 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. 
+ if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False + if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -788,6 +798,22 @@ def _tuned_grouped_mm_common( **config.kwargs, ) + if use_blackwell_cutedsl_grouped_mm( + mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result + ): + for config in get_groupgemm_configs(): + kwargs = dict( + ACC_DTYPE="cutlass.Float32", + ) + + cutedsl_grouped_mm_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + **asdict(config), + ) + input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja new file mode 100644 index 00000000000..989f297c5f8 --- /dev/null +++ b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja @@ -0,0 +1,333 @@ +import functools +from torch._inductor.runtime.runtime_utils import ceildiv +from cutlass.utils import TensorMapUpdateMode +{{gen_defines()}} +# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- +from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( + GroupedGemmKernel, +) + + +# Note about caching: +# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor +# maintains its own local caching system. 
At this stage, all compile-time +# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel +# name itself ({{kernel_name}}) are permanently baked into the file, so they +# do not need to be included in any cache key. +# +# The caching mechanism is split into two levels: +# +# 1. prep_cache +# Caches the compiled executor for build_group_ptrs_from_bases(). This +# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, +# and can therefore be safely reused across runs with different group +# partitioning (`offs`). +# +# 2. gemm_cache +# Caches the compiled Grouped GEMM executor. Its key extends the prep +# cache key with hardware- and grid-specific parameters: +# (prep_cache_key, max_active_clusters, total_num_clusters). +# This is necessary because different `offs` tensors can change the +# per-group problem sizes and thus alter `total_num_clusters`, which in +# turn changes the grid shape and persistent scheduler configuration. +# Kernels compiled for one grid cannot be safely reused for another. +# +# +# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, +# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, +# despite depending only on the GPU type. We cache this function to mitigate +# redundant recompiles even when shape/stride/dtype cache misses force kernel +# regeneration. A follow-up study will investigate the root cause. + +prep_cache = {} +gemm_cache = {} + + +@functools.lru_cache +def get_hardware_info(): + hw = cutlass.utils.HardwareInfo() + sm_count = hw.get_max_active_clusters(1) + max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) + + return (sm_count, max_active_clusters) + + +def get_prep_cache_key(input_a, input_b, output): + """ + Returns a tuple key for caching the preprocessing kernel executor based on kernel name, + shapes, strides, and dtypes of input/output tensors. + """ + return ( + tuple(input_a.shape), + tuple(input_a.stride()), + input_a.dtype, + tuple(input_b.shape), + tuple(input_b.stride()), + input_b.dtype, + tuple(output.shape), + tuple(output.stride()), + output.dtype, + ) + + +def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): + """ + Returns a tuple key for caching the gemm kernel executor by extending the + prep cache key with hardware- and grid-specific parameters. 
+ """ + return ( + prep_cache_key, + max_active_clusters, + total_num_clusters, + ) + + +@cute.kernel +def build_group_ptrs_from_bases_kernel( + base_A_u64: cutlass.Int64, # device addr of input_a (bytes) + base_B_u64: cutlass.Int64, # device addr of input_b (bytes) + base_C_u64: cutlass.Int64, # device addr of Output (bytes) + offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Int32, # bytes + # -------- STRIDES (in ELEMENTS) -------- + stride_A_m_elems: cutlass.Constexpr, # A.stride(0) + stride_A_k_elems: cutlass.Constexpr, # A.stride(1) + stride_B0_elems: cutlass.Constexpr, # B.stride(0) + stride_Bk_elems: cutlass.Constexpr, # B.stride(1) + stride_Bn_elems: cutlass.Constexpr, # B.stride(2) + stride_C_m_elems: cutlass.Constexpr, # C.stride(0) + stride_C_n_elems: cutlass.Constexpr, # C.stride(1) + # -------- OUTPUTS -------- + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) + out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) + out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] +): + tidx, _, _ = cute.arch.thread_idx() + g = tidx + + m_beg_i32 = 0 + if g > 0: + m_beg_i32 = offs[g - 1] + m_end_i32 = offs[g] + m_g_i32 = m_end_i32 - m_beg_i32 + + a_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) + ) + c_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) + ) + b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) + + # ---- pointers ---- + out_ptrs[g, 0] = base_A_u64 + a_byte_off + out_ptrs[g, 1] = base_B_u64 + b_byte_off + out_ptrs[g, 2] = base_C_u64 + c_byte_off + + # ---- (m, n, k, 1) ---- + out_problem[g, 0] = m_g_i32 + out_problem[g, 1] = N + out_problem[g, 2] = K + out_problem[g, 3] = cutlass.Int32(1) + + # ---- strides ---- + out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) + out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) + out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) + out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) + out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) + out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) + + +@cute.jit +def launch_build_group_ptrs_from_bases( + base_A_u64: cutlass.Int64, + base_B_u64: cutlass.Int64, + base_C_u64: cutlass.Int64, + offs: cute.Tensor, + G: cutlass.Constexpr, + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Constexpr, + stride_A_m_elems: cutlass.Constexpr, + stride_A_k_elems: cutlass.Constexpr, + stride_B0_elems: cutlass.Constexpr, + stride_Bk_elems: cutlass.Constexpr, + stride_Bn_elems: cutlass.Constexpr, + stride_C_m_elems: cutlass.Constexpr, + stride_C_n_elems: cutlass.Constexpr, + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 + out_problem: cute.Tensor, # [G,4] cutlass.Int32 + out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 + stream: cuda.CUstream, +): + build_group_ptrs_from_bases_kernel( + base_A_u64, + base_B_u64, + base_C_u64, + offs, + K, + N, + sizeof_element, + stride_A_m_elems, + stride_A_k_elems, + stride_B0_elems, + stride_Bk_elems, + stride_Bn_elems, + stride_C_m_elems, + stride_C_n_elems, + out_ptrs, + out_problem, + out_strides_abc, + ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) + + +{{def_kernel("input_a", "input_b", "input_a_offs")}} + stream = cuda.CUstream(stream) + + input_b = input_b.transpose(1, 2) + + sumM, K = input_a.shape + G, N, Kb = input_b.shape + + dev = 
input_a.device + + base_A_u64 = int(input_a.data_ptr()) + base_B_u64 = int(input_b.data_ptr()) + base_C_u64 = int({{get_output()}}.data_ptr()) + + ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) + probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) + strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) + ptrs = from_dlpack(ptrs_t) + probs = from_dlpack(probs_t) + strides = from_dlpack(strides_t) + + prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) + prep_executor = prep_cache.get(prep_cache_key) + + if prep_executor is None: + sizeof_element = int(input_a.element_size()) + sA_m, sA_k = map(int, input_a.stride()) + sB_0, sB_n, sB_k = map(int, input_b.stride()) + sC_m, sC_n = map(int, {{get_output()}}.stride()) + + prep_executor = cute.compile( + launch_build_group_ptrs_from_bases, + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + G=int(G), + K=int(K), + N=int(N), + sizeof_element=sizeof_element, + stride_A_m_elems=sA_m, + stride_A_k_elems=sA_k, + stride_B0_elems=sB_0, + stride_Bk_elems=sB_k, + stride_Bn_elems=sB_n, + stride_C_m_elems=sC_m, + stride_C_n_elems=sC_n, + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + prep_cache[prep_cache_key] = prep_executor + + prep_executor( + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + # --- Tensormap workspace per SM --- + num_tensormap_buffers, max_active_clusters = get_hardware_info() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) + tensormap_workspace = from_dlpack(tensormap_workspace_t) + + # --- Total clusters --- + def compute_total_num_clusters( + problem_sizes_mnkl, + cluster_tile_shape_mn, + ): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + ): + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) + ) + + total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) + + gemm_cache_key = get_gemm_cache_key( + prep_cache_key, max_active_clusters, total_num_clusters + ) + gemm_executor = gemm_cache.get(gemm_cache_key) + + if gemm_executor is None: + grouped_gemm = GroupedGemmKernel( + acc_dtype=ACC_DTYPE, + use_2cta_instrs=USE_2_CTA, + mma_tiler_mn=(TILE_M, TILE_N), + cluster_shape_mn=(CLUSTER_M, CLUSTER_N), + tensormap_update_mode=TENSORMAP_UPDATE_MODE, + ) + + gemm_executor = cute.compile( + grouped_gemm, + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + G, + probs, + strides, + ptrs, + total_num_clusters, + tensormap_workspace, + 
max_active_clusters, + stream, + ) + + gemm_cache[gemm_cache_key] = gemm_executor + + gemm_executor( + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + probs, + strides, + ptrs, + tensormap_workspace, + stream, + ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py new file mode 100644 index 00000000000..db337b9d8a2 --- /dev/null +++ b/torch/_inductor/template_heuristics/cutedsl.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import auto, Enum +from itertools import product + +import torch._inductor.config as config + + +class TensorMapUpdateMode(Enum): + """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" + + SMEM = auto() + GMEM = auto() + + +@dataclass(frozen=True) +class CuTeGemmConfig: + TILE_M: int = 128 + TILE_N: int = 192 + CLUSTER_M: int = 2 + CLUSTER_N: int = 1 + USE_2_CTA: bool = False + TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM + + +def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + For information regarding valid config sets, see: + https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py + """ + + # Tile_n is always the same regardless of 2cta + tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] + + # Valid clusters + clusters_no_2cta = [ + (1, 1), + (1, 2), + (1, 4), + (1, 8), + (1, 16), + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + clusters_2cta = [ + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + + configs: list[CuTeGemmConfig] = [] + + for use_2cta, cluster_set, tile_m_range in [ + (False, clusters_no_2cta, [64, 128]), + (True, clusters_2cta, [128, 256]), + ]: + for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( + [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], + tile_m_range, + tile_n_vals, + cluster_set, + ): + configs.append( + CuTeGemmConfig( + tile_m, + tile_n, + cluster_m, + cluster_n, + USE_2_CTA=use_2cta, + TENSORMAP_UPDATE_MODE=tensormap_update_mode, + ) + ) + + return configs + + +def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
+ """ + + config_tuples = [ + (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), + (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), + (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), + (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), + (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), + (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + ] + + return [CuTeGemmConfig(*args) for args in config_tuples] + + +def get_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + + Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures + or unstable results. By default, autotuning is disabled and we return only + a single baseline config. + """ + if ( + config.cutedsl_enable_autotuning + and config.max_autotune_gemm_search_space == "EXHAUSTIVE" + ): + return get_exhaustive_groupgemm_configs() + elif config.cutedsl_enable_autotuning: + return get_default_groupgemm_configs() + else: + return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 05b1b9bd33a..f98d3385b18 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1911,6 +1911,84 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() +@functools.lru_cache(maxsize=1) +def ensure_cute_available() -> bool: + """Check if CuTeDSL is importable; cache the result for reuse. + + Call ensure_cute_available.cache_clear() after installing CuTeDSL + in the same interpreter to retry the import. + """ + try: + return importlib.util.find_spec("cutlass.cute") is not None + except ImportError: + return False + + +def use_blackwell_cutedsl_grouped_mm( + mat_a: Any, + mat_b: Any, + layout: Layout, + a_is_2d: bool, + b_is_2d: bool, + offs: Optional[Any], + bias: Optional[Any], + scale_result: Optional[Any], +) -> bool: + """ + Returns True if we can use the blackwell kernel for grouped mm. + Required conditions: + 1. CuTeDSL backend is enabled + 2. CuTeDSL is available + 3. We are on a blackwell arch + 4. The dtype is bf16 + 5. Max autotune or max autotune gemm is enabled + 6. A, B, and the output are 16B aligned + 7. We are not using dynamic shapes + 8. A is 2d + 9. B is 3d + 10. Offsets are provided + 11. 
Bias and Scale are not provided + """ + if not ensure_cute_available(): + return False + + if not _use_autotune_backend("CUTEDSL"): + return False + + from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch + + if not is_gpu(layout.device.type): + return False + + if not is_datacenter_blackwell_arch(): + return False + + layout_dtypes = [torch.bfloat16] + if not _use_template_for_gpu(layout, layout_dtypes): + return False + + if not (config.max_autotune or config.max_autotune_gemm): + return False + + # Checks for 16B ptr and stride alignment + if not can_use_tma(mat_a, mat_b, output_layout=layout): + return False + + if any(is_dynamic(x) for x in [mat_a, mat_b]): + return False + + if not a_is_2d or b_is_2d: + return False + + if offs is None: + return False + + if bias is not None or scale_result is not None: + return False + + return True + + def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V
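
A minimal usage sketch of the new CuTeDSL grouped-MM path, following the pattern of test/inductor/test_cutedsl_grouped_mm.py above. It assumes a datacenter Blackwell GPU, bfloat16 inputs, and an importable CuTeDSL (cutlass.cute); the shapes and the exact set of config keys are illustrative, not prescribed by the patch.

    # Hedged sketch: exercises the CUTEDSL grouped-MM backend added in this patch.
    # Assumes a datacenter Blackwell GPU and that `cutlass.cute` is importable.
    import torch
    from torch._inductor import config

    device, dtype = "cuda", torch.bfloat16
    G, K, N = 4, 64, 128
    M_sizes = [128] * G  # per-group M; multiples of 16, matching the test's alignment

    A = torch.randn(sum(M_sizes), K, device=device, dtype=dtype)  # 2D packed A
    B = torch.randn(G, K, N, device=device, dtype=dtype)          # 3D batched B
    offs = torch.cumsum(torch.tensor(M_sizes, device=device), 0).to(torch.int32)

    def grouped_mm(a, b, offs):
        return torch._grouped_mm(a, b, offs=offs)

    # The CUTEDSL choice is only considered when max-autotune is enabled and the
    # backend is listed; set CUTEDSL_ENABLE_AUTOTUNING=1 in the environment to
    # sweep more than the single baseline config from get_groupgemm_configs().
    with config.patch(
        {
            "max_autotune": True,
            "max_autotune_gemm_backends": "CUTEDSL",
            "test_configs.autotune_choice_name_regex": "cutedsl",
        }
    ):
        compiled = torch.compile(grouped_mm, backend="inductor", dynamic=False)
        out = compiled(A, B, offs)

    # Compiled result should match eager _grouped_mm within bf16 tolerances.
    torch.testing.assert_close(out, grouped_mm(A, B, offs))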