From 466ea8ce54777dee62e1d2d9bb4cb111910b17f4 Mon Sep 17 00:00:00 2001 From: albanD Date: Fri, 26 Jul 2024 15:40:17 -0400 Subject: [PATCH] Add fallback() to torch.library (#131707) Pull Request resolved: https://github.com/pytorch/pytorch/pull/131707 Approved by: https://github.com/zou3519 --- c10/core/impl/PyInterpreter.cpp | 3 +- c10/core/impl/PyInterpreter.h | 3 +- .../TestPythonRegistration.test_fallback | 0 ...estPythonRegistration.test_fallback_keyset | 0 test/test_python_dispatch.py | 131 ++++++++++++++++++ torch/csrc/PyInterpreter.cpp | 81 ++++++----- torch/csrc/PyInterpreter.h | 6 + torch/csrc/utils/python_dispatch.cpp | 74 ++++++++-- torch/csrc/utils/python_dispatch.h | 3 +- torch/library.py | 39 ++++++ 10 files changed, 286 insertions(+), 54 deletions(-) create mode 100644 test/dynamo_skips/TestPythonRegistration.test_fallback create mode 100644 test/dynamo_skips/TestPythonRegistration.test_fallback_keyset diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp index 04f9d7c9722..51622572f6a 100644 --- a/c10/core/impl/PyInterpreter.cpp +++ b/c10/core/impl/PyInterpreter.cpp @@ -36,7 +36,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { c10::DispatchKey, c10::DispatchKeySet keyset, torch::jit::Stack* stack, - bool with_keyset) const override { + bool with_keyset, + bool with_op) const override { PANIC(python_op_registration_trampoline); } diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index 2685496899e..c0976c841ff 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -150,7 +150,8 @@ struct C10_API PyInterpreterVTable { c10::DispatchKey, c10::DispatchKeySet keyset, torch::jit::Stack* stack, - bool with_keyset) const = 0; + bool with_keyset, + bool with_op) const = 0; virtual void throw_abstract_impl_not_imported_error( std::string opname, diff --git a/test/dynamo_skips/TestPythonRegistration.test_fallback b/test/dynamo_skips/TestPythonRegistration.test_fallback new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/dynamo_skips/TestPythonRegistration.test_fallback_keyset b/test/dynamo_skips/TestPythonRegistration.test_fallback_keyset new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index efb30ff1c12..0b3d9d487f2 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -69,6 +69,137 @@ class TestPythonRegistration(TestCase): if hasattr(torch.ops, self.test_ns): del torch.ops._test_python_registration + def test_fallback(self) -> None: + test_key = "TESTING_ONLY_GenericMode" + test_keyset = torch._C.DispatchKeySet(test_key) + include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset + exclude_to_set = torch._C._dispatch_tls_local_exclude_set() + + with _scoped_library("_", "IMPL") as my_lib: + expected_op = None + expected_args = None + expected_kwargs = None + # Use this out shape to make sure the result from our fallback + # is what is returned to the user + out_shape = None + + def my_fallback(op, *args, **kwargs): + # Disable our handler during checks and generating the output + with torch._C._ForceDispatchKeyGuard( + include_to_set, exclude_to_set | test_keyset + ): + self.assertIs(op, expected_op) + self.assertEqual(args, expected_args) + self.assertEqual(kwargs, expected_kwargs) + # Return something specific + return torch.empty(out_shape) + + my_lib.fallback(my_fallback, test_key) + + a, b = torch.rand(2), torch.rand(2) + + with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): + # Check a factory function + expected_op = torch.ops.aten.empty.memory_format + expected_args = ((2, 2),) + # Extra kwargs to bypass issues with default args in factory functions + expected_kwargs = { + "dtype": torch.float64, + "pin_memory": False, + "device": torch.device("cpu"), + } + out_shape = (3,) + out = torch.empty(*expected_args, **expected_kwargs) + self.assertEqual(out.size(), out_shape) + + # Check a regular function + expected_op = torch.ops.aten.add.Tensor + expected_args = (a, b) + expected_kwargs = {} + out_shape = (4,) + out = a + b + self.assertEqual(out.size(), out_shape) + + def test_fallback_keyset(self) -> None: + test_key_first = "TESTING_ONLY_GenericMode" + test_key_second = "TESTING_ONLY_GenericWrapper" + test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet( + test_key_second + ) + include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset + exclude_to_set = torch._C._dispatch_tls_local_exclude_set() + + with _scoped_library("_", "IMPL") as my_lib: + first_called = False + second_called = False + + def first_fallback(keyset, op, *args, **kwargs): + nonlocal first_called + if second_called: + # Recursive call + first_called = True + with torch._C._ForceDispatchKeyGuard( + include_to_set, exclude_to_set | test_keyset + ): + return op(*args, **kwargs) + else: + # Redispatch down + keyset = keyset.remove(test_key_first) + return op.redispatch(keyset, *args, **kwargs) + + def second_fallback(op, *args, **kwargs): + nonlocal second_called + # Set to avoid infinite recursion + second_called = True + # New dispatcher call should hit the first callback again + self.assertFalse(first_called) + a, b = args + # Make a substraction here instead of add ! + c = a - b + self.assertTrue(first_called) + return c + + my_lib.fallback(first_fallback, test_key_first, with_keyset=True) + my_lib.fallback(second_fallback, test_key_second) + + a, b = torch.rand(2), torch.rand(2) + with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): + c = a + b + + self.assertEqual(c, a - b) + self.assertTrue(first_called) + self.assertTrue(second_called) + + def test_fallback_fallthrough(self) -> None: + test_key_first = "TESTING_ONLY_GenericMode" + test_key_second = "TESTING_ONLY_GenericWrapper" + test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet( + test_key_second + ) + include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset + exclude_to_set = torch._C._dispatch_tls_local_exclude_set() + + with _scoped_library("_", "IMPL") as my_lib: + is_called = False + + def my_fallback(op, *args, **kwargs): + nonlocal is_called + is_called = True + with torch._C._ForceDispatchKeyGuard( + include_to_set, exclude_to_set | test_keyset + ): + return op(*args, **kwargs) + + my_lib.fallback(torch.library.fallthrough_kernel, test_key_first) + my_lib.fallback(my_fallback, test_key_second) + + a, b = torch.rand(2), torch.rand(2) + with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): + c = a + b + + self.assertEqual(c, a + b) + self.assertTrue(is_called) + def test_override_aten_ops_with_multiple_libraries(self) -> None: x = torch.tensor([1, 2]) with _scoped_library("aten", "IMPL") as my_lib2: diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index a7e5c5e9fb8..adcbb8a9cf0 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -12,6 +12,8 @@ using namespace torch; using namespace at; using namespace c10; +namespace torch::detail { + namespace { // NB: This is a macro and not a template function (like it was before) @@ -62,9 +64,10 @@ struct ConcretePyInterpreterVTable final c10::DispatchKey key, c10::DispatchKeySet keyset, torch::jit::Stack* stack, - bool with_keyset) const override { + bool with_keyset, + bool with_op) const override { torch::impl::dispatch::python_op_registration_trampoline_impl( - op, key, keyset, stack, with_keyset); + op, key, keyset, stack, with_keyset, with_op); } void throw_abstract_impl_not_imported_error( std::string opname, @@ -272,30 +275,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot) Py_DECREF(pyobj); }; -py::handle getTorchApiFunction(const c10::OperatorHandle& op) { - return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* { - // Parse the name into namespace and name (no overload_name) - // TODO: put this into the library - const auto& schema = op.schema(); - const auto& qualified_name = op.operator_name().name; - const auto& overload_name = schema.overload_name(); - auto pos = qualified_name.find("::"); - TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); - // Make me some null terminated strings - std::string ns_str = qualified_name.substr(0, pos); - const char* ns = ns_str.c_str(); - const char* func_name = qualified_name.c_str() + pos + strlen("::"); - - py::handle torch_api_function = - py::module::import("torch").attr("ops").attr(ns).attr(func_name); - if (overload_name.empty()) { - return torch_api_function.attr("default").ptr(); - } else { - return torch_api_function.attr(overload_name.c_str()).ptr(); - } - }); -} - bool isPythonTensor(const at::Tensor& tensor) { return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); } @@ -956,20 +935,46 @@ void ConcretePyInterpreterVTable::reset_backward_hooks( END_HANDLE_TH_ERRORS_PYBIND } -PyInterpreterHolder self_interpreter; - -} // anonymous namespace - -c10::impl::PyInterpreter* getPyInterpreter() { - return self_interpreter.get(); -} - -bool isMainPyInterpreter() { - return self_interpreter.is_main_interpreter(); -} - std::string ConcretePyInterpreterVTable::name() const { std::stringstream ss; ss << getPyInterpreter(); return ss.str(); } + +PyInterpreterHolder self_interpreter; + +} // anonymous namespace + +py::handle getTorchApiFunction(const c10::OperatorHandle& op) { + return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* { + // Parse the name into namespace and name (no overload_name) + // TODO: put this into the library + const auto& schema = op.schema(); + const auto& qualified_name = op.operator_name().name; + const auto& overload_name = schema.overload_name(); + auto pos = qualified_name.find("::"); + TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); + // Make me some null terminated strings + std::string ns_str = qualified_name.substr(0, pos); + const char* ns = ns_str.c_str(); + const char* func_name = qualified_name.c_str() + pos + strlen("::"); + + py::handle torch_api_function = + py::module::import("torch").attr("ops").attr(ns).attr(func_name); + if (overload_name.empty()) { + return torch_api_function.attr("default").ptr(); + } else { + return torch_api_function.attr(overload_name.c_str()).ptr(); + } + }); +} + +} // namespace torch::detail + +c10::impl::PyInterpreter* getPyInterpreter() { + return torch::detail::self_interpreter.get(); +} + +bool isMainPyInterpreter() { + return torch::detail::self_interpreter.is_main_interpreter(); +} diff --git a/torch/csrc/PyInterpreter.h b/torch/csrc/PyInterpreter.h index 30809ff10be..82ca11e2c5d 100644 --- a/torch/csrc/PyInterpreter.h +++ b/torch/csrc/PyInterpreter.h @@ -2,6 +2,12 @@ #include #include +#include +namespace torch::detail { +TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op); +} + +// TODO: Move these to a proper namespace TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter(); TORCH_PYTHON_API bool isMainPyInterpreter(); diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 5807debfeff..eda4a721338 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -108,15 +108,19 @@ class PythonKernelHolder : public c10::OperatorKernel { c10::DispatchKey dispatch_key_; // If "with_keyset", then we expect a keyset as the first arg. bool with_keyset_; + // If "with_op", then we expect the op as first arg (or second if keyset) + bool with_op_; public: PythonKernelHolder( py::object func, c10::DispatchKey dispatch_key, - bool with_keyset = false) + bool with_keyset = false, + bool with_op = false) : func_(func.release().ptr(), getPyInterpreter()), dispatch_key_(dispatch_key), - with_keyset_(with_keyset) {} + with_keyset_(with_keyset), + with_op_(with_op) {} void operator()( const c10::OperatorHandle& op, @@ -132,7 +136,7 @@ class PythonKernelHolder : public c10::OperatorKernel { c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); cur_torch_dispatch_mode_state->pyinterpreter() ->python_op_registration_trampoline( - op, dispatch_key_, keyset, stack, with_keyset_); + op, dispatch_key_, keyset, stack, with_keyset_, with_op_); return; } @@ -150,7 +154,7 @@ class PythonKernelHolder : public c10::OperatorKernel { at::DispatchKey::Python)) { (*interpreter) ->python_op_registration_trampoline( - op, dispatch_key_, keyset, stack, with_keyset_); + op, dispatch_key_, keyset, stack, with_keyset_, with_op_); return; } } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) { @@ -166,7 +170,7 @@ class PythonKernelHolder : public c10::OperatorKernel { nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) { (*interpreter) ->python_op_registration_trampoline( - op, dispatch_key_, keyset, stack, with_keyset_); + op, dispatch_key_, keyset, stack, with_keyset_, with_op_); return; } } @@ -189,9 +193,18 @@ class PythonKernelHolder : public c10::OperatorKernel { auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); auto func = py::reinterpret_borrow(func_.ptr(getPyInterpreter())); - auto obj = with_keyset_ - ? func(keyset, *args_kwargs.first, **args_kwargs.second) - : func(*args_kwargs.first, **args_kwargs.second); + auto obj = with_op_ ? with_keyset_ + ? func( + keyset, + torch::detail::getTorchApiFunction(op), + *args_kwargs.first, + **args_kwargs.second) + : func( + torch::detail::getTorchApiFunction(op), + *args_kwargs.first, + **args_kwargs.second) + : with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second) + : func(*args_kwargs.first, **args_kwargs.second); if (!obj) { throw python_error(); } @@ -461,7 +474,33 @@ void initDispatchBindings(PyObject* module) { return self; }, "", - py::arg("dispatch") = ""); + py::arg("dispatch") = "") + .def( + "fallback", + [](const py::object& self, + c10::DispatchKey dispatch, + const py::object& func, + bool with_keyset) { + HANDLE_TH_ERRORS + auto& lib = self.cast(); + TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); + if (func.is(py::module::import("torch.library") + .attr("fallthrough_kernel"))) { + lib.fallback( + torch::dispatch(dispatch, CppFunction::makeFallthrough())); + } else { + lib.fallback(torch::dispatch( + dispatch, + CppFunction::makeFromBoxedFunctor( + std::make_unique( + func, dispatch, with_keyset, /*with_op*/ true)))); + } + END_HANDLE_TH_ERRORS_PYBIND + }, + "", + py::arg("dispatch"), + py::arg("func"), + py::arg("with_keyset") = false); m.def( "_dispatch_library", @@ -954,7 +993,8 @@ void python_op_registration_trampoline_impl( c10::DispatchKey key, c10::DispatchKeySet keyset, torch::jit::Stack* stack, - bool with_keyset) { + bool with_keyset, + bool with_op) { auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); @@ -963,9 +1003,17 @@ void python_op_registration_trampoline_impl( auto* pyobj = func->ptr(getPyInterpreter()); TORCH_INTERNAL_ASSERT(pyobj != nullptr); auto callable = py::reinterpret_borrow(pyobj); - auto obj = with_keyset - ? callable(keyset, *args_kwargs.first, **args_kwargs.second) - : callable(*args_kwargs.first, **args_kwargs.second); + auto obj = with_op ? with_keyset ? callable( + keyset, + torch::detail::getTorchApiFunction(op), + *args_kwargs.first, + **args_kwargs.second) + : callable( + torch::detail::getTorchApiFunction(op), + *args_kwargs.first, + **args_kwargs.second) + : with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second) + : callable(*args_kwargs.first, **args_kwargs.second); if (!obj) { throw python_error(); } diff --git a/torch/csrc/utils/python_dispatch.h b/torch/csrc/utils/python_dispatch.h index 32d436d8347..aeb20a33859 100644 --- a/torch/csrc/utils/python_dispatch.h +++ b/torch/csrc/utils/python_dispatch.h @@ -10,6 +10,7 @@ void python_op_registration_trampoline_impl( c10::DispatchKey key, c10::DispatchKeySet keyset, torch::jit::Stack* stack, - bool with_keyset); + bool with_keyset, + bool with_op); } // namespace torch::impl::dispatch diff --git a/torch/library.py b/torch/library.py index abf7db574ae..5a282799c75 100644 --- a/torch/library.py +++ b/torch/library.py @@ -278,6 +278,8 @@ class Library: to register a fallthrough. dispatch_key: dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with. + with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument + to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. Example:: >>> my_lib = Library("aten", "IMPL") @@ -345,6 +347,43 @@ class Library: _impls.add(key) self._op_impls.add(key) + def fallback(self, fn, dispatch_key="", *, with_keyset=False): + r"""Registers the function implementation as the fallback for the given key. + + This function only works for a library with global namespace ("_"). + + Args: + fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel` + to register a fallthrough. + dispatch_key: dispatch key that the input function should be registered for. By default, it uses + the dispatch key that the library was created with. + with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument + to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. + + Example:: + >>> my_lib = Library("_", "IMPL") + >>> def fallback_kernel(op, *args, **kwargs): + >>> # Handle all autocast ops generically + >>> # ... + >>> my_lib.fallback(fallback_kernel, "Autocast") + """ + if dispatch_key == "": + dispatch_key = self.dispatch_key + + if self.ns != "_": + raise RuntimeError( + f"""Fallback can only be registered using libary fragment on the global namespace "_" but it is {self.ns}""" + ) + + assert dispatch_key != "" + assert self.m is not None + + self.m.fallback( + dispatch_key, + fn, + with_keyset, + ) + def _destroy(self): if self.m is not None: self.m.reset()