[OpenReg][1/N] Migrate cpp_extensions_open_device_registration to OpenReg (#156588)
- fake tensor
- named tensor
- custom autograd function

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156588
Approved by: https://github.com/albanD
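For reference, a minimal sketch of the first two behaviors being migrated, assuming the OpenReg extension is built and registers the "openreg" device (the custom autograd part is sketched after the new TORCH_LIBRARY registration further down):

```python
import torch

# Fake tensors on the custom device: no real allocation happens, but shape
# and device propagation are still exercised.
with torch._subclasses.fake_tensor.FakeTensorMode():
    a = torch.empty(1, device="openreg")
    b = torch.empty(1, device="openreg:0")
    _ = a + b

# Named tensors on the custom device.
torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])
```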
@@ -139,36 +139,6 @@ void fallback_with_undefined_tensor() {
      grad_scale, found_inf);
}

struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {

  static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
    return self;
  }

  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
    return {grad_output[0] * 0.5};
  }
};

struct CustomAutogradFnAliasing : public torch::autograd::Function<CustomAutogradFnAliasing> {

  static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
    return self.view_symint(self.sym_sizes());
  }

  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
    return {grad_output[0] * 0.5};
  }
};

at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
  return CustomAutogradFnReturnsSelf::apply(x);
}

at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
  return CustomAutogradFnAliasing::apply(x);
}

// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
// that's implemented in C++.
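The comment above describes the pybind11 side; on the Python side the extension is JIT-built and the exposed helper is called roughly like this. This is a sketch only: the module name and source path are illustrative and mirror the load call in the test file further down.

```python
import torch
import torch.utils.cpp_extension

# Build the C++ extension in-process; PYBIND11_MODULE above provides the
# Python-visible helpers such as custom_device().
module = torch.utils.cpp_extension.load(
    name="custom_device_extension",                # illustrative name
    sources=["open_registration_extension.cpp"],   # illustrative path
    verbose=True,
)

device = module.custom_device()        # a torch.device for the custom backend
x = torch.empty(4, 4, device=device)   # allocates through the PrivateUse1 backend
```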
@@ -179,14 +149,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
  m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");
  m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");

  // Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
  m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
}

TORCH_LIBRARY(_test_funcs, m) {
  m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}
TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) {
  m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing);
}

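For context, the removed registrations above were driven from Python roughly as below, mirroring the deleted tests later in this diff; the ops only exist once the old extension is loaded.

```python
import torch

# The pybind-registered function was called directly off the loaded module:
#   module.custom_autograd_fn_returns_self(x)
# while the _test_funcs op went through the dispatcher and was also exercised
# under torch.compile:
x = torch.randn(4, requires_grad=True)
compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
out = compiled(x)
out.sum().backward()
```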
@@ -2,22 +2,26 @@

#include <ATen/EmptyTensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/set_cpu_dispatch.h>
#include <ATen/ops/set_native.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/quantized/AffineQuantizer.h>

#include <c10/core/Allocator.h>

#include <torch/library.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/jit/serialization/pickler.h>

#include <torch/library.h>

namespace openreg {
namespace {

@@ -275,6 +279,44 @@ void quantize_tensor_per_tensor_affine_privateuse1(
  // Just test the process, so do nothing
}

struct CustomAutogradFnReturnsSelf
    : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
  static at::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      at::Tensor self) {
    return self;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    return {grad_output[0] * 0.5};
  }
};

struct CustomAutogradFnAliasing
    : public torch::autograd::Function<CustomAutogradFnAliasing> {
  static at::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      at::Tensor self) {
    return self.view_symint(self.sym_sizes());
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    return {grad_output[0] * 0.5};
  }
};

at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
  return CustomAutogradFnReturnsSelf::apply(x);
}

at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
  return CustomAutogradFnAliasing::apply(x);
}

/* Notes:
 *
 * OpenReg is currently designed to simulate device memory through multiple
@@ -362,3 +404,15 @@ REGISTER_PRIVATEUSE1_DISPATCH(
    _fused_sdp_choice_stub,
    &openreg::_fused_sdp_choice_privateuse1);
} // namespace at::native

TORCH_LIBRARY(openreg, m) {
  m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor");
  m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}

TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) {
  m.impl("custom_autograd_fn_aliasing", &openreg::custom_autograd_fn_aliasing);
  m.impl(
      "custom_autograd_fn_returns_self",
      &openreg::custom_autograd_fn_returns_self);
}

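These registrations surface the C++ autograd functions under torch.ops.openreg.*, dispatching to the AutogradPrivateUse1 kernels when called with "openreg" tensors; the Tensor(a) annotation in the aliasing schema declares that the output aliases the input, matching the view_symint in the C++ forward. A sketch of the call path and the expected gradient, assuming the OpenReg extension is loaded:

```python
import torch

x = torch.randn(4, device="openreg", requires_grad=True)

# Dispatched to CustomAutogradFnReturnsSelf through the AutogradPrivateUse1 kernel.
out = torch.ops.openreg.custom_autograd_fn_returns_self(x)
out.sum().backward()
# The custom backward returns grad_output * 0.5, so x.grad ends up 0.5
# everywhere rather than the all-ones gradient a plain identity would give.

# The aliasing variant returns a view of its input.
y = torch.ops.openreg.custom_autograd_fn_aliasing(x)
```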
@@ -53,46 +53,6 @@ class TestCppExtensionOpenRegistration(common.TestCase):
            verbose=True,
        )

    def test_open_device_faketensor(self):
        with torch._subclasses.fake_tensor.FakeTensorMode.push():
            a = torch.empty(1, device="openreg")
            b = torch.empty(1, device="openreg:0")
            result = a + b  # noqa: F841

    def test_open_device_named_tensor(self):
        torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    def test_compile_autograd_function_returns_self(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.detach().clone().requires_grad_(True)
        f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    @common.skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket")
    def test_compile_autograd_function_aliasing(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.detach().clone().requires_grad_(True)
        f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    def test_open_device_scalar_type_fallback(self):
        z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
        z = torch.triu_indices(3, 3, device="openreg")

@@ -382,6 +382,15 @@ class TestOpenReg(TestCase):
        self.assertEqual(z.device.type, "openreg")
        self.assertEqual(z.shape, torch.Size([0]))

    def test_fake_tensor(self):
        with torch._subclasses.fake_tensor.FakeTensorMode():
            a = torch.empty(1, device="openreg")
            b = torch.empty(1, device="openreg:0")
            result = a + b  # noqa: F841

    def test_named_tensor(self):
        return torch.empty([2, 3, 4, 5], device="openreg", names=["N", "C", "H", "W"])

    def test_printing(self):
        a = torch.ones(20, device="openreg")
        # Does not crash!
@@ -424,6 +433,38 @@ class TestOpenReg(TestCase):
        self.assertEqual(quantized_tensor.device, torch.device("openreg:0"))
        self.assertEqual(quantized_tensor.dtype, torch.qint8)

    # custom autograd
    def test_compile_autograd_function_returns_self(self):
        in_ref = torch.randn(4, device="openreg", requires_grad=True)
        out_ref = torch.ops.openreg.custom_autograd_fn_returns_self(in_ref)
        out_ref.sum().backward()

        in_test = in_ref.detach().clone().requires_grad_(True)
        # TODO(FFFrog): Need to support inductor for OpenReg first.
        out_test = torch.compile(backend="aot_eager")(
            torch.ops.openreg.custom_autograd_fn_returns_self
        )(in_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(in_ref.grad, in_test.grad)

    @skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket")
    def test_compile_autograd_function_aliasing(self):
        in_ref = torch.randn(4, device="openreg", requires_grad=True)
        out_ref = torch.ops.openreg.custom_autograd_fn_aliasing(in_ref)
        out_ref.sum().backward()

        in_test = in_ref.detach().clone().requires_grad_(True)
        # TODO(FFFrog): Need to support inductor for OpenReg first.
        out_test = torch.compile(backend="aot_eager")(
            torch.ops.openreg.custom_autograd_fn_aliasing
        )(in_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(in_ref.grad, in_test.grad)


if __name__ == "__main__":
    run_tests()