mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[dynamo] replace unimplemented with unimplemented_v2 in variables/torch_functions.py (#151278)
This addresses part of #147913. Pull Request resolved: https://github.com/pytorch/pytorch/pull/151278 Approved by: https://github.com/Skylion007, https://github.com/williamwen42 ghstack dependencies: #151277
This commit is contained in:
committed by
PyTorch MergeBot
parent
9e24f9b523
commit
51e77f3b30
@@ -769,7 +769,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
|
||||
def fn(x):
|
||||
return x.ndim
|
||||
|
||||
msg = "Currently only support accessing overridden attributes that are functions or properties, but got <class 'int'>"
|
||||
msg = "`torch.compile` only support tracing certain types of overriden tensor subclass attributes"
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
|
||||
x = torch.ones(2, 2).as_subclass(LocalSubclass)
|
||||
fn(x)
|
||||
|
||||
@@ -44,7 +44,8 @@ from torch.overrides import (
|
||||
)
|
||||
from torch.utils._device import DeviceContext
|
||||
|
||||
from ..exc import unimplemented
|
||||
from .. import graph_break_hints
|
||||
from ..exc import unimplemented_v2
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..polyfills import NoEnterTorchFunctionMode
|
||||
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
|
||||
@@ -567,8 +568,13 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
|
||||
if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
|
||||
return res
|
||||
|
||||
unimplemented(
|
||||
f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
|
||||
unimplemented_v2(
|
||||
gb_type="TypeError from user code",
|
||||
context=f"{fn=}, {args=}, {kwargs=}",
|
||||
explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented",
|
||||
hints=[
|
||||
*graph_break_hints.USER_ERROR,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -621,9 +627,17 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
# base tensors, custom attribute accesses will graph break.
|
||||
import torch
|
||||
|
||||
# I think only `_base` is breaking because we aren't modelling view
|
||||
# relationship perfectly in some scenarios.
|
||||
if name in banned_attrs:
|
||||
unimplemented(
|
||||
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported tensor subclass attribute access",
|
||||
context=f"{name}",
|
||||
explanation="`torch.compile` currently can't trace this",
|
||||
hints=[
|
||||
f"Avoid accessing {name} of tensor subclass in torch.compile region",
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
# Handle non-overriden attributes inherited from `torch.Tensor`.
|
||||
@@ -676,8 +690,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
)
|
||||
|
||||
elif attr_is_overriden:
|
||||
unimplemented(
|
||||
f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported tensor subclass overriden attribute access",
|
||||
context=f"{name}",
|
||||
explanation="`torch.compile` only support tracing certain types of overriden tensor subclass attributes",
|
||||
hints=[
|
||||
f"Avoid accessing {name} of tensor subclass in torch.compile region",
|
||||
f"Renaming attribute `{name}` of type {self.class_type}",
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
return super().var_getattr(tx, name)
|
||||
@@ -709,9 +730,15 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
import torch
|
||||
|
||||
if _is_attr_overidden(tx, self, name):
|
||||
unimplemented(
|
||||
f"Calling overridden method {name} on a tensor"
|
||||
" subclass with a __torch_function__ override is not supported"
|
||||
unimplemented_v2(
|
||||
gb_type="Tensor subclass overriden method call",
|
||||
context=f"{name}",
|
||||
explanation="`torch.compile` currently can't trace this",
|
||||
hints=[
|
||||
f"Avoid calling {name} of tensor subclass in torch.compile region",
|
||||
f"Renaming method `{name}` of type {self.class_type}",
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
)
|
||||
|
||||
# [Note: __torch_function__] Currently we only support methods that are defined on tensor
|
||||
|
||||
Reference in New Issue
Block a user