diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 662f5420bfc..7185682df70 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -7,6 +7,11 @@ import unittest import torch import torch._dynamo.test_case from torch.testing._internal.common_utils import IS_FBCODE +from torch.testing._internal.inductor_utils import requires_triton +from torch.utils._triton import ( + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, +) def _filter_instructions(instructions, opname): @@ -397,6 +402,52 @@ class ReconstructTest(torch._dynamo.test_case.TestCase): inp = torch.randn(3) self.assertEqual(gn(inp), inp + 3) + @requires_triton() + @unittest.skipIf( + not has_triton_experimental_host_tma(), + "Test requires triton.tools.experimental_descriptor API", + ) + def test_tma_experimental_reconstruct(self): + import triton + + def create_tma(tensor): + tma = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + tensor.data_ptr(), + tensor.size(0), + tensor.size(1), + 32, + 32, + tensor.element_size(), + ) + return tensor + 1, tma + + x = torch.randn(128, 128, device="cuda") + + ref = create_tma(x) + res = torch.compile(create_tma, backend="eager")(x) + self.assertEqual(ref[1].desc, res[1].desc) + + @requires_triton() + @unittest.skipIf( + not has_triton_tensor_descriptor_host_tma(), + "Test requires triton.tools.tensor_descriptor API", + ) + def test_tma_stable_reconstruct(self): + import triton + + def create_tma(tensor): + tma = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + tensor, + [32, 32], + ) + return tensor + 1, tma + + x = torch.randn(128, 128, device="cuda") + + ref = create_tma(x) + res = torch.compile(create_tma, backend="eager")(x) + self.assertEqual(ref, res) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 45ca28fd645..b9c1beed121 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -54,7 +54,8 @@ from .distributed import BackwardHookVariable, DistributedVariable, PlacementVar from .functions import ( BuiltinMethodVariable, CollectionsNamedTupleFunction, - CreateTMADescriptorVariable, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, FunctionDecoratedByContextlibContextManagerVariable, FunctoolsPartialVariable, FunctoolsWrapsVariable, @@ -63,7 +64,8 @@ from .functions import ( NestedUserFunctionVariable, PolyfilledFunctionVariable, SkipFunctionVariable, - TMADescriptorVariable, + TMADescriptorExperimentalVariable, + TMADescriptorStableVariable, UserFunctionVariable, UserMethodVariable, ) @@ -157,7 +159,8 @@ __all__ = [ "ConstDictVariable", "ContextWrappingVariable", "CountIteratorVariable", - "CreateTMADescriptorVariable", + "CreateTMADescriptorExperimentalVariable", + "CreateTMADescriptorStableVariable", "CUDADeviceVariable", "CycleIteratorVariable", "DataPtrVariable", @@ -198,7 +201,8 @@ __all__ = [ "SuperVariable", "TemporarilyPopInterpreterStackCtxManagerVariable", "TensorVariable", - "TMADescriptorVariable", + "TMADescriptorExperimentalVariable", + "TMADescriptorStableVariable", "TorchCtxManagerClassVariable", "TorchInGraphFunctionVariable", "TorchVersionVariable", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 878dcb8a013..da628411305 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -189,7 +189,8 @@ 
from .functions import ( BuiltinMethodVariable, CollectionsNamedTupleFunction, CollectiveFunctionRewriteVariable, - CreateTMADescriptorVariable, + CreateTMADescriptorExperimentalVariable, + CreateTMADescriptorStableVariable, FunctoolsPartialVariable, FunctoolsWrapsVariable, SysFunctionVariable, @@ -606,7 +607,11 @@ class VariableBuilder: def _wrap(self, value): # import here to avoid circular dependencies - from torch.utils._triton import has_triton, has_triton_tma + from torch.utils._triton import ( + has_triton, + has_triton_experimental_host_tma, + has_triton_tensor_descriptor_host_tma, + ) from ..decorators import DynamoConfigPatchProxy @@ -621,18 +626,25 @@ class VariableBuilder: class Autotuner: pass - if has_triton_tma(): - from triton.tools.experimental_descriptor import ( + # default implementations, in case we don't have triton (or the wrong triton version) + def create_1d_tma_descriptor(): + pass + + def create_2d_tma_descriptor(): + pass + + class TensorDescriptor: + @staticmethod + def from_tensor(): + pass + + if has_triton_experimental_host_tma(): + from triton.tools.experimental_descriptor import ( # noqa: F811 create_1d_tma_descriptor, create_2d_tma_descriptor, ) - else: - - def create_1d_tma_descriptor(): - pass - - def create_2d_tma_descriptor(): - pass + if has_triton_tensor_descriptor_host_tma(): + from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811 # Handle exact type() match type_dispatch = self._type_dispatch().get(type(value)) @@ -1111,9 +1123,11 @@ class VariableBuilder: source=self.source, ) elif value is create_1d_tma_descriptor: - return CreateTMADescriptorVariable(rank=1) + return CreateTMADescriptorExperimentalVariable(rank=1) elif value is create_2d_tma_descriptor: - return CreateTMADescriptorVariable(rank=2) + return CreateTMADescriptorExperimentalVariable(rank=2) + elif value is TensorDescriptor.from_tensor: + return CreateTMADescriptorStableVariable() elif isinstance(value, torch.amp.autocast_mode.autocast): self.install_guards(GuardBuilder.ID_MATCH) return AutocastModeVariable( diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index d569aa8e800..bfbd04e14cd 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -2059,16 +2059,19 @@ class DynamoTritonHOPifier(TritonHOPifier): from .dicts import ConstDictVariable # as we can only pass tensors as non-const args in fx graph, - # here we replace TMA descriptors (TMADescriptorVariable + # here we replace TMA descriptors + # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable # instances) with the underlying tensors, while moving the # TMA descriptor-related metadata to a separate argument, # so that we can reconstruct the TMA descriptors downstream tma_descriptor_metadata: TMADescriptorMetadata = {} for k in list(combined_args_raw.keys()): v = combined_args_raw[k] - if isinstance(v, TMADescriptorVariable): + if isinstance( + v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable) + ): tma_descriptor_metadata[k] = v.to_metadata() - combined_args_raw[k] = v.data_ptr.from_tensor + combined_args_raw[k] = v.get_tensor() combined_args = { variables.ConstantVariable.create(k): v @@ -2170,7 +2173,7 @@ class TritonKernelVariable(VariableTracker): return arg -class TMADescriptorVariable(VariableTracker): +class TMADescriptorExperimentalVariable(VariableTracker): def __init__( self, data_ptr: "variables.DataPtrVariable", @@ -2205,8 +2208,45 @@ class TMADescriptorVariable(VariableTracker): 
codegen.foreach(args) codegen.call_function(len(args) + 1, False) + def get_tensor(self): + return self.data_ptr.from_tensor -class CreateTMADescriptorVariable(VariableTracker): + +class TMADescriptorStableVariable(VariableTracker): + def __init__( + self, + tensor: "variables.TensorVariable", + block_shape: "variables.ListVariable", + **kwargs, + ): + assert isinstance(tensor, variables.TensorVariable) + super().__init__(**kwargs) + self.tensor = tensor + self.block_shape = block_shape + + def to_metadata(self): + # TODO(dberard) implement this + raise NotImplementedError( + "TensorDescriptor.from_tensor support is not yet implemented" + ) + + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.tensor_descriptor", + "TensorDescriptor", + ) + ) + codegen.load_method("from_tensor") + self.tensor.reconstruct(codegen) + codegen(self.block_shape) + codegen.call_method(2) + + def get_tensor(self) -> "variables.TensorVariable": + return self.tensor + + +class CreateTMADescriptorExperimentalVariable(VariableTracker): def __init__( self, rank: int, @@ -2251,9 +2291,25 @@ class CreateTMADescriptorVariable(VariableTracker): ] element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1] - return TMADescriptorVariable( + return TMADescriptorExperimentalVariable( data_ptr=ptr, dims=dims, block_dims=block_dims, element_size=element_size, ) + + +class CreateTMADescriptorStableVariable(VariableTracker): + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] + block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] + + return TMADescriptorStableVariable( + tensor=tensor, + block_shape=block_shape, + )
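
A minimal usage sketch of what this change enables, mirroring test_tma_experimental_reconstruct and test_tma_stable_reconstruct above (assumes a CUDA device and a Triton build that exposes both triton.tools.experimental_descriptor and triton.tools.tensor_descriptor). Dynamo intercepts both host-side TMA descriptor constructors, passes the underlying tensor through the FX graph, and reconstructs the descriptor in the generated bytecode when it escapes the compiled region:

    import torch
    import triton
    from triton.tools.tensor_descriptor import TensorDescriptor

    def make_descriptors(t):
        # experimental API: traced via CreateTMADescriptorExperimentalVariable
        exp = triton.tools.experimental_descriptor.create_2d_tma_descriptor(
            t.data_ptr(), t.size(0), t.size(1), 32, 32, t.element_size()
        )
        # stable API: traced via CreateTMADescriptorStableVariable
        stable = TensorDescriptor.from_tensor(t, [32, 32])
        return t + 1, exp, stable

    x = torch.randn(128, 128, device="cuda")
    out, exp, stable = torch.compile(make_descriptors, backend="eager")(x)
    # both returned descriptors are real host-side TMA descriptors again,
    # matching what eager execution of make_descriptors produces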