[OpenReg][1/N] Migrate cpp_extensions_open_device_registration to OpenReg (#156588)

----

Test coverage migrated in this first step (a brief sketch follows the commit metadata below):

- fake tensor
- named tensor
- custom autograd function
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156588
Approved by: https://github.com/albanD
Author: FFFrog
Date: 2025-06-25 17:47:22 +08:00
Committed by: PyTorch MergeBot
Parent: 4585c33e74
Commit: a730c65fe3
4 changed files with 100 additions and 85 deletions
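
For orientation, here is a minimal sketch of the three behaviors being migrated, written the way the new OpenReg tests below exercise them; it assumes the OpenReg extension is built and registered as the PrivateUse1 backend named "openreg".

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Fake tensor: tensors on the custom device can be traced without real allocations.
with FakeTensorMode():
    a = torch.empty(1, device="openreg")
    b = torch.empty(1, device="openreg:0")
    _ = a + b

# Named tensor: dimension names are accepted for openreg allocations.
t = torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])

# Custom autograd function: the C++ op registered under the openreg namespace
# is reachable via torch.ops and differentiable on the custom device.
x = torch.randn(4, device="openreg", requires_grad=True)
torch.ops.openreg.custom_autograd_fn_returns_self(x).sum().backward()
```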


@@ -139,36 +139,6 @@ void fallback_with_undefined_tensor() {
       grad_scale, found_inf);
 }
-struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
-  static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
-    return self;
-  }
-  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
-    return {grad_output[0] * 0.5};
-  }
-};
-struct CustomAutogradFnAliasing : public torch::autograd::Function<CustomAutogradFnAliasing> {
-  static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
-    return self.view_symint(self.sym_sizes());
-  }
-  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
-    return {grad_output[0] * 0.5};
-  }
-};
-at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
-  return CustomAutogradFnReturnsSelf::apply(x);
-}
-at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
-  return CustomAutogradFnAliasing::apply(x);
-}
 // Here, we're exposing a custom device object that corresponds to our custom backend.
 // We do this using pybind: exposing an "extension_name.custom_device()" function in python,
 // that's implemented in C++.
@@ -179,14 +149,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
   m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");
   m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
-  // Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
-  m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
 }
-TORCH_LIBRARY(_test_funcs, m) {
-  m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
-}
-TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) {
-  m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing);
-}
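
The two hunks above drop the custom autograd helpers from the old open-registration extension; they are re-added under the `openreg` namespace in the next file. As a rough before/after of the Python call sites used by the tests in this commit (`module` below stands for the pybind module the old test file loaded):

```python
import torch

# Before: one op came from the extension's pybind module, the other from a
# throwaway _test_funcs library whose autograd kernel targeted AutogradCPU.
#   out = module.custom_autograd_fn_returns_self(x)
#   out = torch.ops._test_funcs.custom_autograd_fn_aliasing(x)

# After: both ops live in the openreg namespace, with autograd kernels bound
# to AutogradPrivateUse1, so the inputs are expected on the custom device.
x = torch.randn(4, device="openreg", requires_grad=True)
out = torch.ops.openreg.custom_autograd_fn_aliasing(x)
```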


@@ -2,22 +2,26 @@
 #include <ATen/EmptyTensor.h>
 #include <ATen/TensorIterator.h>
 #include <ATen/TensorOperators.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/UnaryOps.h>
+#include <ATen/native/quantized/AffineQuantizer.h>
+#include <ATen/native/transformers/attention.h>
+#include <ATen/native/transformers/sdp_utils_cpp.h>
 #include <ATen/ops/as_strided_cpu_dispatch.h>
 #include <ATen/ops/quantize_per_tensor_native.h>
 #include <ATen/ops/resize_native.h>
 #include <ATen/ops/set_cpu_dispatch.h>
 #include <ATen/ops/set_native.h>
-#include <ATen/native/DispatchStub.h>
-#include <ATen/native/transformers/attention.h>
-#include <ATen/native/transformers/sdp_utils_cpp.h>
-#include <ATen/native/quantized/AffineQuantizer.h>
 #include <c10/core/Allocator.h>
-#include <torch/library.h>
+#include <torch/csrc/autograd/custom_function.h>
+#include <torch/csrc/autograd/function_hook.h>
+#include <torch/csrc/jit/serialization/pickler.h>
+#include <torch/library.h>
 namespace openreg {
 namespace {
@@ -275,6 +279,44 @@ void quantize_tensor_per_tensor_affine_privateuse1(
   // Just test the process, so do nothing
 }
+struct CustomAutogradFnReturnsSelf
+    : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
+  static at::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      at::Tensor self) {
+    return self;
+  }
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    return {grad_output[0] * 0.5};
+  }
+};
+struct CustomAutogradFnAliasing
+    : public torch::autograd::Function<CustomAutogradFnAliasing> {
+  static at::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      at::Tensor self) {
+    return self.view_symint(self.sym_sizes());
+  }
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    return {grad_output[0] * 0.5};
+  }
+};
+at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
+  return CustomAutogradFnReturnsSelf::apply(x);
+}
+at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
+  return CustomAutogradFnAliasing::apply(x);
+}
 /* Notes:
  *
  * OpenReg is currently designed to simulate device memory through multiple
@@ -362,3 +404,15 @@ REGISTER_PRIVATEUSE1_DISPATCH(
     _fused_sdp_choice_stub,
     &openreg::_fused_sdp_choice_privateuse1);
 } // namespace at::native
+TORCH_LIBRARY(openreg, m) {
+  m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor");
+  m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
+}
+TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
+  m.impl("custom_autograd_fn_aliasing", &openreg::custom_autograd_fn_aliasing);
+  m.impl(
+      "custom_autograd_fn_returns_self",
+      &openreg::custom_autograd_fn_returns_self);
+}
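
Both backward implementations above scale the incoming gradient by 0.5, so the new registration can be sanity-checked from Python once the OpenReg extension is loaded; a small sketch:

```python
import torch

x = torch.randn(4, device="openreg", requires_grad=True)
out = torch.ops.openreg.custom_autograd_fn_returns_self(x)
out.sum().backward()

# backward() returns grad_output * 0.5, so the gradient of sum() is 0.5 everywhere.
assert torch.equal(x.grad.cpu(), torch.full((4,), 0.5))
```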


@@ -53,46 +53,6 @@ class TestCppExtensionOpenRegistration(common.TestCase):
             verbose=True,
         )
-    def test_open_device_faketensor(self):
-        with torch._subclasses.fake_tensor.FakeTensorMode.push():
-            a = torch.empty(1, device="openreg")
-            b = torch.empty(1, device="openreg:0")
-            result = a + b  # noqa: F841
-    def test_open_device_named_tensor(self):
-        torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])
-    # Not an open registration test - this file is just very convenient
-    # for testing torch.compile on custom C++ operators
-    def test_compile_autograd_function_returns_self(self):
-        x_ref = torch.randn(4, requires_grad=True)
-        out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
-        out_ref.sum().backward()
-        x_test = x_ref.detach().clone().requires_grad_(True)
-        f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
-        out_test = f_compiled(x_test)
-        out_test.sum().backward()
-        self.assertEqual(out_ref, out_test)
-        self.assertEqual(x_ref.grad, x_test.grad)
-    # Not an open registration test - this file is just very convenient
-    # for testing torch.compile on custom C++ operators
-    @common.skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket")
-    def test_compile_autograd_function_aliasing(self):
-        x_ref = torch.randn(4, requires_grad=True)
-        out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
-        out_ref.sum().backward()
-        x_test = x_ref.detach().clone().requires_grad_(True)
-        f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
-        out_test = f_compiled(x_test)
-        out_test.sum().backward()
-        self.assertEqual(out_ref, out_test)
-        self.assertEqual(x_ref.grad, x_test.grad)
     def test_open_device_scalar_type_fallback(self):
         z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
         z = torch.triu_indices(3, 3, device="openreg")


@@ -382,6 +382,15 @@ class TestOpenReg(TestCase):
         self.assertEqual(z.device.type, "openreg")
         self.assertEqual(z.shape, torch.Size([0]))
+    def test_fake_tensor(self):
+        with torch._subclasses.fake_tensor.FakeTensorMode():
+            a = torch.empty(1, device="openreg")
+            b = torch.empty(1, device="openreg:0")
+            result = a + b  # noqa: F841
+    def test_named_tensor(self):
+        return torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])
     def test_printing(self):
         a = torch.ones(20, device="openreg")
         # Does not crash!
@@ -424,6 +433,38 @@ class TestOpenReg(TestCase):
         self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
         self.assertEqual(quantized_tensor.dtype, torch.qint8)
+    # custom autograd
+    def test_compile_autograd_function_returns_self(self):
+        in_ref = torch.randn(4, device="openreg", requires_grad=True)
+        out_ref = torch.ops.openreg.custom_autograd_fn_returns_self(in_ref)
+        out_ref.sum().backward()
+        in_test = in_ref.detach().clone().requires_grad_(True)
+        # TODO(FFFrog): Need to support inductor for OpenReg first.
+        out_test = torch.compile(backend="aot_eager")(
+            torch.ops.openreg.custom_autograd_fn_returns_self
+        )(in_test)
+        out_test.sum().backward()
+        self.assertEqual(out_ref, out_test)
+        self.assertEqual(in_ref.grad, in_test.grad)
+    @skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket")
+    def test_compile_autograd_function_aliasing(self):
+        in_ref = torch.randn(4, device="openreg", requires_grad=True)
+        out_ref = torch.ops.openreg.custom_autograd_fn_aliasing(in_ref)
+        out_ref.sum().backward()
+        in_test = in_ref.detach().clone().requires_grad_(True)
+        # TODO(FFFrog): Need to support inductor for OpenReg first.
+        out_test = torch.compile(backend="aot_eager")(
+            torch.ops.openreg.custom_autograd_fn_aliasing
+        )(in_test)
+        out_test.sum().backward()
+        self.assertEqual(out_ref, out_test)
+        self.assertEqual(in_ref.grad, in_test.grad)
 if __name__ == "__main__":
     run_tests()