From 5a607febc04c3a2b5824c75f3f60307867439a2c Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 27 Nov 2025 20:20:06 +0000 Subject: [PATCH] Back out "Make PT2 compile backprop through custom op without autograd key a hard error (#166367)" (#168142) Summary: Original commit changeset: 7148dc4803f5 Original Phabricator Diff: D86736500 Differential Revision: D87407335 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168142 Approved by: https://github.com/wdvr --- aten/src/ATen/native/TensorCompare.cpp | 9 -- aten/src/ATen/native/native_functions.yaml | 5 -- test/distributed/test_inductor_collectives.py | 10 ++- test/test_autograd_fallback.py | 11 ++- torch/_functorch/aot_autograd.py | 4 - torch/_higher_order_ops/effects.py | 1 - torch/_library/autograd.py | 11 --- torch/_subclasses/fake_impls.py | 5 -- .../autograd_not_implemented_fallback.cpp | 90 +++++++------------ torch/fx/node.py | 1 - torchgen/native_function_generation.py | 1 - 11 files changed, 48 insertions(+), 100 deletions(-) diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 8a0b38eafab..1a3843e9cdc 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -480,14 +479,6 @@ Tensor isfinite(const Tensor& self) { }); } -void _async_error(std::string_view msg) { - TORCH_CHECK(0, msg); -} - -void _async_error_meta(std::string_view msg) { - // Do NOT error, it's an async error! -} - void _assert_async_cpu(const Tensor& self) { TORCH_CHECK( native::is_nonzero(self), diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4fa24ff378d..81a782f7332 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -192,11 +192,6 @@ CompositeExplicitAutograd: _assert_tensor_metadata Meta: _assert_tensor_metadata_meta_symint -- func: _async_error(str msg) -> () - dispatch: - CompositeExplicitAutograd: _async_error - Meta: _async_error_meta - - func: _print(str s) -> () dispatch: CompositeExplicitAutograd: _print diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index fdf03fdf3a1..52062616a85 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1348,11 +1348,13 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): assert counter.op_count == 3 # It generates 2 getattr to unpack the array assert same(out, correct) - # This doesn't work in all cases, and now we properly loudly error. - # See: https://github.com/pytorch/pytorch/issues/151240 - # When differentiable funcols are implemented can revert. - @unittest.expectedFailure def test_backwards(self): + """ + It's probably not that common to need backwards support for collectives. + + However, I wanted to at least see if it was possible to support it as a design goal. 
+ """ + def func(inp): ar = _functional_collectives.all_reduce(inp, "sum", "0") return ar diff --git a/test/test_autograd_fallback.py b/test/test_autograd_fallback.py index 5748b5c4cca..d6252ac6f34 100644 --- a/test/test_autograd_fallback.py +++ b/test/test_autograd_fallback.py @@ -6,7 +6,6 @@ import warnings import numpy as np import torch -from torch._library.autograd import autograd_fallback_mode from torch.library import _scoped_library from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -16,6 +15,16 @@ from torch.testing._internal.common_utils import ( ) +@contextlib.contextmanager +def autograd_fallback_mode(mode): + prev = torch._C._get_autograd_fallback_mode() + try: + torch._C._set_autograd_fallback_mode(mode) + yield + finally: + torch._C._set_autograd_fallback_mode(prev) + + class TestAutogradFallback(TestCase): test_ns = "_test_autograd_fallback" diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 8555026122e..9fdebe6396d 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -26,7 +26,6 @@ from torch._dynamo.utils import ( from torch._guards import detect_fake_mode from torch._inductor.cudagraph_utils import BoxedDeviceIndex from torch._inductor.utils import BoxedBool -from torch._library.autograd import autograd_fallback_mode from torch._subclasses import FakeTensor, FakeTensorMode from torch.export._tree_utils import reorder_kwargs from torch.fx.experimental.proxy_tensor import make_fx @@ -529,9 +528,6 @@ def create_aot_state( stack.enter_context( torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing() ) - # Make it an error to backprop through PT2 compliant ops that silently - # detach autograd - stack.enter_context(autograd_fallback_mode("error")) from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj from torch._library.opaque_object import is_opaque_type diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index b2fc74b7328..86707a4f55e 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -59,7 +59,6 @@ def _get_effect(op: _op_identifier) -> Optional[_EffectType]: _register_effectful_op("aten::_print", _EffectType.ORDERED) -_register_effectful_op("aten::_async_error", _EffectType.ORDERED) _register_effectful_op("profiler::_record_function_exit._RecordFunction", None) _register_effectful_op(call_torchbind, _EffectType.ORDERED) _register_effectful_op(hop_print, _EffectType.ORDERED) diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py index 125ed5b73d8..2707d07059e 100644 --- a/torch/_library/autograd.py +++ b/torch/_library/autograd.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import contextlib import dataclasses from collections.abc import Callable from dataclasses import dataclass @@ -236,16 +235,6 @@ def not_list_of_optional_tensor(tree): return True -@contextlib.contextmanager -def autograd_fallback_mode(mode): - prev = _C._get_autograd_fallback_mode() - try: - _C._set_autograd_fallback_mode(mode) - yield - finally: - _C._set_autograd_fallback_mode(prev) - - flatten = _pytree.tree_flatten unflatten = _pytree.tree_unflatten spec_t = _pytree.TreeSpec diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 530c8d939d7..ff309af8a29 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -223,11 +223,6 @@ def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs): return r 
-@register_op_impl(aten._async_error.default) -def _async_error(fake_mode, func, msg: str): - pass - - @register_op_impl(aten.to.prim_Device) @register_op_impl(aten.to.device) def non_kwarg_to(fake_mode, func, *args, **kwargs): diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index a4a9afec1a7..386a8a9df53 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -6,12 +6,6 @@ #include #include -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#endif - #include #include #include @@ -70,6 +64,7 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn; } // namespace void setAutogradFallbackMode(AutogradFallbackMode mode) { + TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'"); kAutogradFallbackMode = mode; } @@ -77,60 +72,41 @@ AutogradFallbackMode getAutogradFallbackMode() { return kAutogradFallbackMode; } -static void reportAutogradNotImplemented( - const std::string& op_name, - bool is_warn) { - if (is_warn) { - TORCH_WARN( - op_name, - ": an autograd kernel was not registered to the Autograd key(s) ", - "but we are trying to backprop through it. This may lead to silently incorrect behavior. ", - "This behavior is deprecated and will be removed in a future version of PyTorch. ", - "If your operator is differentiable, please ensure you have registered an " - "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, " - "DispatchKey::CompositeImplicitAutograd). If your operator is not " - "differentiable, or to squash this warning and use the previous behavior, " - "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd."); - } else { - at::_async_error(c10::str( - op_name, - ": an autograd kernel was not registered to the Autograd key(s) ", - "but we are trying to backprop through it. This can lead to silently incorrect behavior. ", - "If your operator is differentiable, please ensure you have registered an " - "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, " - "). If your operator is not " - "differentiable and ensure NO gradients flow through this operator, " - "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.")); - } +static void warnAutogradNotImplemented(const std::string& op_name) { + TORCH_WARN( + op_name, + ": an autograd kernel was not registered to the Autograd key(s) ", + "but we are trying to backprop through it. This may lead to silently incorrect behavior. ", + "This behavior is deprecated and will be removed in a future version of PyTorch. ", + "If your operator is differentiable, please ensure you have registered an " + "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, " + "DispatchKey::CompositeImplicitAutograd). 
If your operator is not " + "differentiable, or to squash this warning and use the previous behavior, " + "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd."); } -struct NotImplementedBackward : public Node { - NotImplementedBackward( +struct WarnNotImplemented : public Node { + WarnNotImplemented( std::string op_name, size_t num_outputs, - bool is_warn, edge_list&& next_edges) : Node(std::move(next_edges)), op_name(std::move(op_name)), - num_outputs(num_outputs), - is_warn(is_warn) {} + num_outputs(num_outputs) {} - NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn) - : op_name(std::move(op_name)), - num_outputs(num_outputs), - is_warn(is_warn) {} + WarnNotImplemented(std::string op_name, size_t num_outputs) + : op_name(std::move(op_name)), num_outputs(num_outputs) {} variable_list apply(variable_list&& inputs) override; std::string op_name; size_t num_outputs; - bool is_warn; }; // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) -auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list { +auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list { auto inputsLocal = std::move(inputs); - reportAutogradNotImplemented(op_name, is_warn); + warnAutogradNotImplemented(op_name); std::vector output(num_outputs); return output; } @@ -149,6 +125,8 @@ static void basicAutogradNotImplementedFallbackImpl( op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack); return; } + TORCH_INTERNAL_ASSERT( + getAutogradFallbackMode() == AutogradFallbackMode::Warn); bool any_input_requires_grad = false; _foreach_tensor( @@ -164,9 +142,7 @@ static void basicAutogradNotImplementedFallbackImpl( // by putting it after the requires_grad checks. any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled(); - bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn; - - std::shared_ptr grad_fn; + std::shared_ptr grad_fn; if (any_input_requires_grad) { // NB: It is standard to collect edges from all tensors // (see generated/VariableTypeEverything.cpp for examples) @@ -178,9 +154,8 @@ static void basicAutogradNotImplementedFallbackImpl( stack, stack_start, num_arguments); - grad_fn = std::shared_ptr( - new NotImplementedBackward( - op_name, all_tensors_on_stack.size(), is_warn), + grad_fn = std::shared_ptr( + new WarnNotImplemented(op_name, all_tensors_on_stack.size()), deleteNode); grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack)); } @@ -216,8 +191,8 @@ static void basicAutogradNotImplementedFallbackImpl( // >>> y = op(k) // >>> torch.autograd.grad(z.sum(), w) if (t.requires_grad()) { - t.register_hook([op_name, is_warn](const at::Tensor& grad) { - reportAutogradNotImplemented(op_name, is_warn); + t.register_hook([op_name](const at::Tensor& grad) { + warnAutogradNotImplemented(op_name); }); // If history is rebased, then we will attempt to warn // on the view's base. This will catch most cases (because @@ -227,19 +202,18 @@ static void basicAutogradNotImplementedFallbackImpl( const auto& base = t._base(); if (base.requires_grad()) { // Can only register_hook on tensors that require grad. 
- base.register_hook( - [op_name, is_warn](const at::TensorBase& grad) { - reportAutogradNotImplemented(op_name, is_warn); - }); + base.register_hook([op_name](const at::TensorBase& grad) { + warnAutogradNotImplemented(op_name); + }); } } return; } // If the post-autograd implementation returns any Tensors that - // don't require grad, then we install the NotImplementedBackward - // grad_fn. This grad_fn warns in backward and returns undefined - // tensor gradients. + // don't require grad, then we install the WarnNotImplemented grad_fn. + // This grad_fn warns in backward and returns undefined tensor + // gradients. // // NOTE [autograd fallback and in-place operations] // If the schema says the output is mutable, and the output diff --git a/torch/fx/node.py b/torch/fx/node.py index 294e15c5502..5afabe40ec3 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -90,7 +90,6 @@ _side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [ _side_effectful_functions: set[Callable[..., Any]] = { torch._assert, torch._assert_async, - _ops.aten._async_error.default, _ops.aten._assert_async.msg, _ops.aten._assert_scalar.default, _ops.aten._assert_tensor_metadata.default, diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 6cbb0568289..f986c77f8fa 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -55,7 +55,6 @@ MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ # All of these operators don't have any tensor like returns FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [ - "_async_error", "_assert_async", # no return "_assert_async.msg", # no return "_assert_tensor_metadata", # no return
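
As a usage note, here is a minimal sketch (not part of the patch) of the warn-mode
fallback behavior this backout restores: backpropagating through an op that has no
Autograd kernel warns and produces undefined (None) gradients instead of raising a
hard error. The namespace "my_sketch_ns", the op name "foo", and the helper
demo_warn_fallback are illustrative only; the pattern mirrors the cases exercised in
test/test_autograd_fallback.py.

import warnings

import torch
from torch.library import _scoped_library


def demo_warn_fallback():
    # Save the current fallback mode and force "warn"; the "error" mode is
    # rejected again after this backout (NYI).
    prev = torch._C._get_autograd_fallback_mode()
    torch._C._set_autograd_fallback_mode("warn")
    try:
        with _scoped_library("my_sketch_ns", "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")
            # Register a CPU kernel only; deliberately no Autograd kernel.
            lib.impl("foo", lambda x: x.detach().clone(), "CPU")

            x = torch.randn(3, requires_grad=True)
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                y = torch.ops.my_sketch_ns.foo(x)
                # The restored fallback attaches a WarnNotImplemented node to y,
                # so backward warns rather than erroring; the gradient reaching
                # x is undefined, so x.grad stays None.
                y.sum().backward()
            assert any(
                "autograd kernel was not registered" in str(w.message)
                for w in caught
            )
    finally:
        torch._C._set_autograd_fallback_mode(prev)


if __name__ == "__main__":
    demo_warn_fallback()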