From df72078fe1339751e702c7511c23b4597d022dcc Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Tue, 8 Jul 2025 00:46:56 +0000 Subject: [PATCH] [dynamo] Replace unimplemented with unimplemented_v2 in `torch/_dynamo/variables/torch.py` (#157344) Fixes part of #147913 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157344 Approved by: https://github.com/williamwen42 Co-authored-by: William Wen --- test/dynamo/test_decorators.py | 40 +--- test/dynamo/test_modes.py | 2 +- torch/_dynamo/graph_break_registry.json | 276 +++++++++++++++++++++- torch/_dynamo/variables/torch.py | 295 +++++++++++++++++++----- 4 files changed, 525 insertions(+), 88 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 6cd8fafa234..70e1946c309 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -671,13 +671,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): fn(p) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: - msg = """ -For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type .Point>, please use one of the following to register the type with pytree: - * `torch.utils._pytree.register_constant` - * `torch.utils._pytree.register_dataclass` - * `torch.utils._pytree.register_pytree_node` -""" # NOQA: B950 - self.assertIn(msg, str(e)) + self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e)) def test_nonstrict_trace_nested_custom_class_error(self): class Point: @@ -723,13 +717,7 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types fn(torch.ones(10), torch.ones(1)) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: - msg = """ -For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type .Point>, please use one of the following to register the type with pytree: - * `torch.utils._pytree.register_constant` - * `torch.utils._pytree.register_dataclass` - * `torch.utils._pytree.register_pytree_node` -""" # NOQA: B950 - self.assertIn(msg, str(e)) + self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e)) def test_nonstrict_newly_constructed_trace_register_constant_type_error(self): class State: @@ -766,12 +754,10 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types fn(x) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: - msg = """ -You are calling a `nonstrict_trace`-ed function with an input that contains an object of type .State>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. - -Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub. -""" # NOQA: B950 - self.assertIn(msg, str(e)) + self.assertIn( + "Input marked with `pytree.register_constant` constructed in the `torch.compile` region", + str(e), + ) def test_nonstrict_trace_object_in_context_error(self): class Point: @@ -814,17 +800,9 @@ Please construct the object _outside_ the `torch.compile` region, or submit an i fn(x, y) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: - msg = """ -You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type .Point> into the context. - -Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to .Point> - * `torch.utils._pytree.register_constant` - * `torch.utils._pytree.register_dataclass` - * `torch.utils._pytree.register_pytree_node` - -If the above doesn't work, please subtmit an issue to GitHub. -""" # NOQA: B950 - self.assertIn(msg, str(e)) + self.assertIn( + "Invalid use of pytree_flatten with nonstrict_trace-ed function", str(e) + ) def test_graph_break(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 30489302645..868627d0202 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -232,7 +232,7 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase): self.assertRaisesRegex( torch._dynamo.exc.Unsupported, - "Popping from an empty torch function mode stack", + "Attempted to pop from empty torch function mode stack", lambda: fn(torch.ones(2, 2)), ) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 153b7132bf1..9071f723fcb 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2172,9 +2172,6 @@ "Hints": [ "Don't mutate `.data` on this tensor, or move ", "the mutation out of `torch.compile` region" - ], - "Additional_Info": [ - "INFO" ] } ], @@ -2199,5 +2196,278 @@ "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." ] } + ], + "GB0223": [ + { + "Gb_type": "torch.compile call with > 1 args", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.", + "Hints": [ + "Remove the torch.compile call or its additional args.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0224": [ + { + "Gb_type": "Attempted to call torch in-graph function on only torch.SymInt arguments", + "Context": "fn={self.value}, args={args}, kwargs={kwargs}", + "Explanation": "Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. Dynamo does not support this.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0225": [ + { + "Gb_type": "Attempted to use tensor creation function with requires_grad=True", + "Context": "fn={self.value}, args={args}, kwargs={kwargs}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Create the tensor outside the compiled region.", + "Do not set `requires_grad=True`.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0226": [ + { + "Gb_type": "`torch.nn.Parameter()` with unsupported data type", + "Context": "data={data}", + "Explanation": "Called `torch.nn.Parameter()` with non-Tensor argument.", + "Hints": [ + "Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0227": [ + { + "Gb_type": "Attempted to use torch.nn.Parameter constructor with tensor subclass", + "Context": "str(data)", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0228": [ + { + "Gb_type": "`torch.nn.Parameter`: cannot convert to traceable tracable", + "Context": "", + "Explanation": "convert_tracable_parameter is set to False.", + "Hints": [ + "Check usage of context manager: do_not_convert_to_tracable_parameter", + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0229": [ + { + "Gb_type": "Unexpected type of data placeholder op for parameter construction", + "Context": "data_node.op={data_node.op}", + "Explanation": "Data node op should be placeholder or get_attr.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0230": [ + { + "Gb_type": "Attempted to use torch.use_deterministic_algorithms(warn_only=True)", + "Context": "mode={mode}, warn_only={warn_only}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Remove param warn_only in function call torch.use_deterministic_algorithms.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0231": [ + { + "Gb_type": "call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`", + "Context": "trace_numpy={config.trace_numpy}", + "Explanation": "Attempted to call `torch.from_numpy` with config `torch._dynamo.config.trace_numpy` set to `False`.", + "Hints": [ + "Change `torch._dynamo.config.trace_numpy` to `True`." + ] + } + ], + "GB0232": [ + { + "Gb_type": "`torch.from_numpy` with NumPy unavailable", + "Context": "", + "Explanation": "Attempted to call `torch.numpy` but NumPy could not be imported.", + "Hints": [ + "Check NumPy version and installation in your environment.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0233": [ + { + "Gb_type": "Attempted to use strided NestedTensor", + "Context": "layout={layout}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Change layout=torch.jagged.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0234": [ + { + "Gb_type": "Attempted to pop from empty torch function mode stack", + "Context": "", + "Explanation": "Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.", + "Hints": [ + "Do not pop from empty torch function mode stack.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0235": [ + { + "Gb_type": "`torch.nn.Parameter` with non-constant Tensor attributes", + "Context": "data={data}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Ensure the Tensor argument's shape, dtype, and device are correct.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0236": [ + { + "Gb_type": "Invalid input type for nonstrict_trace-ed function", + "Context": "Encountered input of type <{type_name}>.", + "Explanation": "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) or pytree containers of those are allowed as inputs. The provided argument contains an unsupported type.", + "Hints": [ + "Use one of the following to register the type with pytree:\n", + "* `torch.utils._pytree.register_constant`\n", + "* `torch.utils._pytree.register_dataclass`\n", + "* `torch.utils._pytree.register_pytree_node`" + ] + } + ], + "GB0237": [ + { + "Gb_type": "non-constant `requires_grad` argument to `torch.nn.Parameter`", + "Context": "requires_grad={requires_grad}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Change `requires_grad` to be a bool.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0238": [ + { + "Gb_type": "Input marked with `pytree.register_constant` constructed in the `torch.compile` region", + "Context": "Input={input_spec_vt}, offending type <{type_name}>.", + "Explanation": "Calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. This is not supported.", + "Hints": [ + "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0239": [ + { + "Gb_type": "Invalid use of pytree_flatten with nonstrict_trace-ed function", + "Context": "Input={input_spec_vt}, offending type <{type_name}>.", + "Explanation": "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered with a `pytree_flatten` that places an object of type <{type_name}> into the context.", + "Hints": [ + "Modifying the `pytree_flatten` to avoid placing the object into the context.", + "Apply one of the following to <{type_name}>:\n", + "* `torch.utils._pytree.register_constant`\n", + "* `torch.utils._pytree.register_dataclass`\n", + "* `torch.utils._pytree.register_pytree_node`", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0240": [ + { + "Gb_type": "Shape mismatch with out= list of tensor variants", + "Context": "fn={self.value}, args={args}, kwargs={kwargs}", + "Explanation": "Shape mismatch when calling {self.value} with `out=`. Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0241": [ + { + "Gb_type": "Attempted to call op with non-contiguous `out=` list of tensors", + "Context": "self.value={self.value}, args={args}, kwargs={kwargs}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0242": [ + { + "Gb_type": "Attempted to call op with non-contiguous `out=` tensor", + "Context": "self.value={self.value}, args={args}, kwargs={kwargs}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0243": [ + { + "Gb_type": "Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type", + "Context": "value={value}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Change use of _ntuple with argument as constant or tensor." + ] + } + ], + "GB0244": [ + { + "Gb_type": "Attempted to use `torch.nn.Parameter()` with export", + "Context": "", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Do not use `torch.nn.Parameter()` with export.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0245": [ + { + "Gb_type": "Attempted to use `nested_tensor` with non-list input", + "Context": "tensor_list={tensor_list}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Change `nested_tensor` with list input.", + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0246": [ + { + "Gb_type": "Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "Dynamo does not support this.", + "Hints": [ + "Explicitly set the `num_classes` param of the function call ", + "`torch.nn.functional.one_hot` to something other than -1." + ] + } + ], + "GB0247": [ + { + "Gb_type": "Shape mismatch with out= tensor variant", + "Context": "fn={self.value}, args={args}, kwargs={kwargs}", + "Explanation": "Shape mismatch when calling {self.value} with `out=`. Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] } \ No newline at end of file diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index ced64f34561..372d0ead589 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -52,7 +52,7 @@ from ..create_parameter_op import ( tracable_create_parameter, ) from ..device_interface import get_registered_device_interfaces -from ..exc import unimplemented, unimplemented_v2 +from ..exc import unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..source import CallFunctionNoArgsSource, SyntheticLocalSource from ..utils import ( @@ -588,7 +588,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): # torch.compile is a no-op in dynamo return args[0] - unimplemented("torch.compile is used as a decorator in the compiled frame") + unimplemented_v2( + gb_type="torch.compile call with > 1 args", + context=f"args={args}, kwargs={kwargs}", + explanation="Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.", + hints=[ + "Remove the torch.compile call or its additional args.", + *graph_break_hints.SUPPORTABLE, + ], + ) @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): @@ -615,7 +623,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): self, tx: "InstructionTranslator", mode, warn_only=False ): if warn_only and warn_only.as_python_constant(): - unimplemented("torch.use_deterministic_algorithms(warn_only=True)") + unimplemented_v2( + gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)", + context=f"mode={mode}, warn_only={warn_only}", + explanation="Dynamo does not support this.", + hints=[ + "Remove param warn_only in function call torch.use_deterministic_algorithms.", + *graph_break_hints.SUPPORTABLE, + ], + ) return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant()) @register(torch.are_deterministic_algorithms_enabled) @@ -666,9 +682,27 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): @register(torch.from_numpy) def handle_from_numpy(self, tx: "InstructionTranslator", *args): if not config.trace_numpy: - unimplemented("torch.from_numpy. config.trace_numpy is False") + unimplemented_v2( + gb_type="call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`", + context=f"trace_numpy={config.trace_numpy}", + explanation=( + "Attempted to call `torch.from_numpy` with config " + "`torch._dynamo.config.trace_numpy` set to `False`." + ), + hints=[ + "Change `torch._dynamo.config.trace_numpy` to `True`.", + ], + ) if not np: - unimplemented("torch.from_numpy. NumPy is not available") + unimplemented_v2( + gb_type="`torch.from_numpy` with NumPy unavailable", + context="", + explanation="Attempted to call `torch.numpy` but NumPy could not be imported.", + hints=[ + "Check NumPy version and installation in your environment.", + *graph_break_hints.USER_ERROR, + ], + ) return wrap_fx_proxy_cls( target_cls=TensorVariable, tx=tx, @@ -880,9 +914,25 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): from .lists import BaseListVariable if layout and layout.as_python_constant() == torch.strided: - unimplemented("torch.compile does not support strided NestedTensor") + unimplemented_v2( + gb_type="Attempted to use strided NestedTensor", + context=f"layout={layout}", + explanation="Dynamo does not support this.", + hints=[ + "Change layout=torch.jagged.", + *graph_break_hints.SUPPORTABLE, + ], + ) if not isinstance(tensor_list, BaseListVariable): - unimplemented("nested_tensor with non-list input") + unimplemented_v2( + gb_type="Attempted to use `nested_tensor` with non-list input", + context=f"tensor_list={tensor_list}", + explanation="Dynamo does not support this.", + hints=[ + "Change `nested_tensor` with list input.", + *graph_break_hints.USER_ERROR, + ], + ) @register(torch.nn.functional.one_hot) def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): @@ -891,8 +941,14 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): and args[1].is_python_constant() and args[1].as_python_constant() == -1 ): - unimplemented( - "torch.nn.functional.one_hot with data-dependent output shape" + unimplemented_v2( + gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", + context=f"args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + "Explicitly set the `num_classes` param of the function call " + "`torch.nn.functional.one_hot` to something other than -1.", + ], ) @register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious) @@ -1061,7 +1117,15 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): ): assert not args and not kwargs if not tx.symbolic_torch_function_state.mode_stack: - raise unimplemented("Popping from an empty torch function mode stack") + unimplemented_v2( + gb_type="Attempted to pop from empty torch function mode stack", + context="", + explanation="Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.", + hints=[ + "Do not pop from empty torch function mode stack.", + *graph_break_hints.USER_ERROR, + ], + ) TorchFunctionModeStackVariable.register_mutation(tx) return tx.symbolic_torch_function_state.pop_torch_function_mode() @@ -1152,13 +1216,20 @@ class TorchInGraphFunctionVariable(BaseTorchVariable): arg_type = flat_arg_vt.python_type() if not is_graphable_type(arg_type): type_name = flat_arg_vt.python_type().__qualname__ - unimplemented( - f""" -For `nonstrict_trace`-ed function, the only allowed input types are basic types (e.g., torch.Tensor, int, float) or pytree containers of those. Here you are calling the function with arguments that contain a value of type <{type_name}>, please use one of the following to register the type with pytree: - * `torch.utils._pytree.register_constant` - * `torch.utils._pytree.register_dataclass` - * `torch.utils._pytree.register_pytree_node` -""" # NOQA: B950 + unimplemented_v2( + gb_type="Invalid input type for nonstrict_trace-ed function", + context=f"Encountered input of type <{type_name}>.", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, float) " + "or pytree containers of those are allowed as inputs. The provided argument contains " + "an unsupported type." + ), + hints=[ + "Use one of the following to register the type with pytree:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + ], ) # Since we checked with `is_graphable` above, `as_proxy` on the @@ -1179,25 +1250,37 @@ For `nonstrict_trace`-ed function, the only allowed input types are basic types import torch.utils._pytree as pytree if pytree.is_constant_class(typ): - unimplemented( - f""" -You are calling a `nonstrict_trace`-ed function with an input that contains an object of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object was constructed _inside_ the `torch.compile` region. - -Please construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub. - """ # NOQA: B950 + unimplemented_v2( + gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function with an input that contains an object " + f"of type <{type_name}>, which was marked with `pytree.register_constant`. However, the object " + "was constructed _inside_ the `torch.compile` region. This is not supported." + ), + hints=[ + "Construct the object _outside_ the `torch.compile` region, or submit an issue to GitHub.", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, ) else: - unimplemented( - f""" -You are calling a `nonstrict_trace`-ed function where one one of the inputs has been registered with a `pytree_flatten` that puts an object of type <{type_name}> into the context. - -Please consider modifying that `pytree_flatten` to avoid putting the object into context, and apply one of the following to <{type_name}> - * `torch.utils._pytree.register_constant` - * `torch.utils._pytree.register_dataclass` - * `torch.utils._pytree.register_pytree_node` - -If the above doesn't work, please subtmit an issue to GitHub. -""" # NOQA: B950 + unimplemented_v2( + gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function", + context=f"Input={input_spec_vt}, offending type <{type_name}>.", + explanation=( + "Calling a `nonstrict_trace`-ed function where one of the inputs has been registered " + f"with a `pytree_flatten` that places an object of type <{type_name}> into the context." + ), + hints=[ + "Modifying the `pytree_flatten` to avoid placing the object into the context.", + f"Apply one of the following to <{type_name}>:\n" + "* `torch.utils._pytree.register_constant`\n" + "* `torch.utils._pytree.register_dataclass`\n" + "* `torch.utils._pytree.register_pytree_node`", + *graph_break_hints.SUPPORTABLE, + ], + from_exc=e, ) fn = self.value @@ -1308,7 +1391,17 @@ To support this behavior, we need to allow const-propping tensors that store sym For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ log.warning(msg) - unimplemented(msg) + unimplemented_v2( + gb_type="Attempted to call torch in-graph function on only torch.SymInt arguments", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Attempted to call {str(self.value)} (that should be put in the FX graph) on only torch.SymInt arguments. " + "Dynamo does not support this." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) # TODO(voz): Replace w/ dynamic shape rewrite table. # Ideally, we would be able to do this at ctor time, but alas we need a combination @@ -1366,9 +1459,15 @@ For now, dynamo will explicitly graph break when it encounters user code with th and "requires_grad" in kwargs and kwargs["requires_grad"].as_python_constant() ): - unimplemented( - """factory functions that return tensors that require grad are not supported. -Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" + unimplemented_v2( + gb_type="Attempted to use tensor creation function with requires_grad=True", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + "Create the tensor outside the compiled region.", + "Do not set `requires_grad=True`.", + *graph_break_hints.SUPPORTABLE, + ], ) # Handle e.g., `torch.add(a, b, out=result)` @@ -1400,12 +1499,27 @@ Either create the tensor outside the compiled region, or do not set the tensor t if saved_out_shape != fake_out.shape: # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") + unimplemented_v2( + gb_type="Shape mismatch with out= list of tensor variants", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Shape mismatch when calling {self.value} with `out=`. " + f"Provided `out=` shape: {saved_out_shape}. Actual shape: {fake_out.shape}." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where output tensor was non-contiguous" + unimplemented_v2( + gb_type="Attempted to call op with non-contiguous `out=` list of tensors", + context=f"self.value={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], ) else: assert isinstance(out_kwarg_vt, TensorVariable) @@ -1414,12 +1528,27 @@ Either create the tensor outside the compiled region, or do not set the tensor t if saved_out_shapes != fake_out.shape: # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") + unimplemented_v2( + gb_type="Shape mismatch with out= tensor variant", + context=f"fn={self.value}, args={args}, kwargs={kwargs}", + explanation=( + f"Shape mismatch when calling {self.value} with `out=`. " + f"Provided `out=` shape: {saved_out_shapes}. Actual shape: {fake_out.shape}." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where output tensor was non-contiguous" + unimplemented_v2( + gb_type="Attempted to call op with non-contiguous `out=` tensor", + context=f"self.value={self.value}, args={args}, kwargs={kwargs}", + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], ) return tensor_variable @@ -1444,7 +1573,14 @@ Either create the tensor outside the compiled region, or do not set the tensor t torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), ) else: - unimplemented(f"torch.nn.modules.utils._ntuple({value})") + unimplemented_v2( + gb_type="Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type", + context=f"value={value}", + explanation="Dynamo does not support this.", + hints=[ + "Change use of _ntuple with argument as constant or tensor.", + ], + ) if self.value is torch.nn.modules.utils._ntuple: return variables.LambdaVariable(handle_ntuple) @@ -1455,16 +1591,40 @@ Either create the tensor outside the compiled region, or do not set the tensor t def call_nn_parameter(cls, tx, data=None, requires_grad=True): """A call to torch.nn.Parameter() gets lifted to before the graph""" if tx.export: - unimplemented("nn parameter construction not supported with export") + unimplemented_v2( + gb_type="Attempted to use `torch.nn.Parameter()` with export", + context="", + explanation="Dynamo does not support this.", + hints=[ + "Do not use `torch.nn.Parameter()` with export.", + *graph_break_hints.SUPPORTABLE, + ], + ) if isinstance(requires_grad, variables.VariableTracker): try: requires_grad = requires_grad.as_python_constant() except NotImplementedError: - unimplemented("Parameter(requires_grad=...) not constant") + unimplemented_v2( + gb_type="non-constant `requires_grad` argument to `torch.nn.Parameter`", + context=f"requires_grad={requires_grad}", + explanation="Dynamo does not support this.", + hints=[ + "Change `requires_grad` to be a bool.", + *graph_break_hints.USER_ERROR, + ], + ) if not isinstance(data, variables.TensorVariable): - unimplemented(f"Parameter(data={data}) not implemented") + unimplemented_v2( + gb_type="`torch.nn.Parameter()` with unsupported data type", + context=f"data={data}", + explanation="Called `torch.nn.Parameter()` with non-Tensor argument.", + hints=[ + "Ensure the argument to `torch.nn.Parameter()` is a `torch.Tensor`.", + *graph_break_hints.USER_ERROR, + ], + ) # this results in cleaner graphs, but only works for inputs if data.source: @@ -1473,17 +1633,41 @@ Either create the tensor outside the compiled region, or do not set the tensor t if isinstance( data, TensorWithTFOverrideVariable ) or is_traceable_wrapper_subclass_type(data.class_type): - unimplemented("Parameter constructor with tensor subclass NYI") + unimplemented_v2( + gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass", + context=str(data), + explanation="Dynamo does not support this.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) if not can_convert_to_tracable_parameter(): - unimplemented("Workaround for issues with nn_parameter construction") + unimplemented_v2( + gb_type="`torch.nn.Parameter`: cannot convert to traceable tracable", + context="", + explanation="convert_tracable_parameter is set to False.", + hints=[ + "Check usage of context manager: do_not_convert_to_tracable_parameter", + *graph_break_hints.DIFFICULT, + ], + ) try: shape = tuple(data.var_getattr(tx, "shape").as_python_constant()) dtype = data.var_getattr(tx, "dtype").as_python_constant() device = data.var_getattr(tx, "device").as_python_constant() except NotImplementedError as e: - unimplemented(f"Parameter not python_constant: {e}") + unimplemented_v2( + gb_type="`torch.nn.Parameter` with non-constant Tensor attributes", + context=f"data={data}", + explanation="Dynamo does not support this.", + hints=[ + "Ensure the Tensor argument's shape, dtype, and device are correct.", + *graph_break_hints.USER_ERROR, + ], + from_exc=e, + ) placeholder = tx.output.synthetic_graph_input( new_parameter_placeholder, [shape, dtype, device, requires_grad] @@ -1535,8 +1719,13 @@ Either create the tensor outside the compiled region, or do not set the tensor t data_node = data.as_proxy().node if data_node.op not in ("placeholder", "get_attr"): - unimplemented( - "Unexpected type of data placeholder op for parameter construction" + unimplemented_v2( + gb_type="Unexpected type of data placeholder op for parameter construction", + context=f"data_node.op={data_node.op}", + explanation="Data node op should be placeholder or get_attr.", + hints=[ + *graph_break_hints.DIFFICULT, + ], ) # add the newly constructed nn.Parameter as a graph input