[dynamo] replace unimplemented with unimplemented_v2 in variables/torch_functions.py (#151278)

This addresses part of #147913.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151278
Approved by: https://github.com/Skylion007, https://github.com/williamwen42
ghstack dependencies: #151277
This commit is contained in:
Ryan Guo
2025-05-02 09:44:06 -07:00
committed by PyTorch MergeBot
parent 9e24f9b523
commit 51e77f3b30
2 changed files with 38 additions and 11 deletions

View File

@@ -769,7 +769,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
def fn(x):
return x.ndim
msg = "Currently only support accessing overridden attributes that are functions or properties, but got <class 'int'>"
msg = "`torch.compile` only support tracing certain types of overriden tensor subclass attributes"
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
x = torch.ones(2, 2).as_subclass(LocalSubclass)
fn(x)

View File

@@ -44,7 +44,8 @@ from torch.overrides import (
)
from torch.utils._device import DeviceContext
from ..exc import unimplemented
from .. import graph_break_hints
from ..exc import unimplemented_v2
from ..guards import GuardBuilder, install_guard
from ..polyfills import NoEnterTorchFunctionMode
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
@@ -567,8 +568,13 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
return res
unimplemented(
f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
unimplemented_v2(
gb_type="TypeError from user code",
context=f"{fn=}, {args=}, {kwargs=}",
explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented",
hints=[
*graph_break_hints.USER_ERROR,
],
)
@@ -621,9 +627,17 @@ class TensorWithTFOverrideVariable(TensorVariable):
# base tensors, custom attribute accesses will graph break.
import torch
# I think only `_base` is breaking because we aren't modelling view
# relationship perfectly in some scenarios.
if name in banned_attrs:
unimplemented(
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
unimplemented_v2(
gb_type="Unsupported tensor subclass attribute access",
context=f"{name}",
explanation="`torch.compile` currently can't trace this",
hints=[
f"Avoid accessing {name} of tensor subclass in torch.compile region",
*graph_break_hints.SUPPORTABLE,
],
)
# Handle non-overriden attributes inherited from `torch.Tensor`.
@@ -676,8 +690,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
)
elif attr_is_overriden:
unimplemented(
f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950
unimplemented_v2(
gb_type="Unsupported tensor subclass overriden attribute access",
context=f"{name}",
explanation="`torch.compile` only support tracing certain types of overriden tensor subclass attributes",
hints=[
f"Avoid accessing {name} of tensor subclass in torch.compile region",
f"Renaming attribute `{name}` of type {self.class_type}",
*graph_break_hints.SUPPORTABLE,
],
)
return super().var_getattr(tx, name)
@@ -709,9 +730,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
import torch
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Calling overridden method {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
unimplemented_v2(
gb_type="Tensor subclass overriden method call",
context=f"{name}",
explanation="`torch.compile` currently can't trace this",
hints=[
f"Avoid calling {name} of tensor subclass in torch.compile region",
f"Renaming method `{name}` of type {self.class_type}",
*graph_break_hints.SUPPORTABLE,
],
)
# [Note: __torch_function__] Currently we only support methods that are defined on tensor