// pytorch/test/cpp/api/autograd.cpp
#include <ATen/core/boxing/impl/test_helpers.h>
#include <gtest/gtest.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/torch.h>
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <test/cpp/api/support.h>
using namespace torch::autograd;
using namespace torch::test;
#define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
#define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))
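// Custom autograd functions (#23572): subclass torch::autograd::Function<T>
// and define static forward()/backward(). Both receive an AutogradContext*,
// which passes state between them: tensors go through
// ctx->save_for_backward() / ctx->get_saved_variables(), other values through
// the ctx->saved_data map of <std::string, at::IValue>. A minimal sketch,
// adapted from the PR summary:
//
//   struct MyFunction : public Function<MyFunction> {
//     static variable_list forward(AutogradContext* ctx, int n, Variable var) {
//       ctx->saved_data["n"] = n;  // save non-tensor data for backward
//       return {var};
//     }
//     static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
//       auto n = ctx->saved_data["n"].toInt();  // read data saved in forward
//       return {grad_output[0] * n};
//     }
//   };
//   // Usage: MyFunction::apply(6, x);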
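// Renders an autograd graph as a nested string of node names: a null node
// prints as "None", otherwise "NodeName(<descriptions of the next edges>)".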
std::string graph_desc(std::shared_ptr<Node> node) {
if (!node) {
return "None";
}
auto result = node->name() + "(";
auto next_edges = node->next_edges();
for (auto& edge : next_edges) {
result += graph_desc(edge.function);
}
return result + ")";
}
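// simple_fn computes x + 2*y + x*y, so d/dx = 1 + y and d/dy = 2 + x; the
// expected gradients asserted in the tests below follow from this.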
Variable simple_fn(const Variable& x, const Variable& y) {
return x + 2 * y + x * y;
}
TEST(AutogradAPITests, RegisterHookVoidReturnAcceptsUndefinedTensor) {
auto x = at::zeros({}, at::kCPU);
x.requires_grad_();
x.register_hook([](at::TensorBase x) { return; });
auto y = torch::autograd::UndefinedGrad().apply({x});
y[0].backward();
}
TEST(AutogradAPITests, RegisterHookTensorReturnAcceptsUndefinedTensor) {
auto x = at::zeros({}, at::kCPU);
x.requires_grad_();
x.register_hook([](at::Tensor x) -> at::Tensor { return x; });
auto y = torch::autograd::UndefinedGrad().apply({x});
y[0].backward();
}
TEST(AutogradAPITests, BackwardSimpleTest) {
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
backward({res.sum()}, {});
ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}
TEST(AutogradAPITests, BackwardTest) {
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
backward({res}, {torch::ones({2, 2})}, {}, true);
backward({res}, {torch::ones({2, 2})});
ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}
TEST(AutogradAPITests, GradSimpleTest) {
// basic grad
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});
ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}
TEST(AutogradAPITests, GradTest) {
Variable x = torch::randn({2, 2}, torch::requires_grad());
Variable y = torch::randn({2, 2}, torch::requires_grad());
auto res = simple_fn(x, y);
res.backward(torch::ones({2, 2}), false, true);
Variable x_grad = y + torch::ones({2, 2});
Variable y_grad = x + torch::ones({2, 2}) * 2;
ASSERT_VARIABLE_EQ(x.grad(), x_grad);
ASSERT_VARIABLE_EQ(y.grad(), y_grad);
Variable grad_sum = 2 * x.grad() + y.grad();
auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);
ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
ASSERT_VARIABLE_EQ(x.grad(), x_grad);
ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}
TEST(AutogradAPITests, GradNonLeafTest) {
Variable x_init = torch::randn({2, 2}, torch::requires_grad());
Variable x = x_init;
Variable y = torch::randn({2, 2}, torch::requires_grad());
Variable grad_output = torch::ones({2, 2});
for (int i = 0; i < 5; ++i) {
auto res = simple_fn(x, y);
auto input_grads = grad({res}, {x}, {grad_output}, {}, true);
Variable grad_x_expected = y + torch::ones({2, 2});
ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
ASSERT_FALSE(x.grad().defined());
ASSERT_FALSE(y.grad().defined());
x = x + 0.05 * input_grads[0];
}
float val_init = simple_fn(x_init, y).sum().item().toFloat();
float val_final = simple_fn(x, y).sum().item().toFloat();
ASSERT_TRUE(val_final > val_init);
x.backward(grad_output, false, true);
ASSERT_TRUE(x_init.grad().defined());
ASSERT_TRUE(y.grad().defined());
}
TEST(AutogradAPITests, GradUnreachableTest) {
Variable x = torch::ones({1}, torch::requires_grad());
Variable y = torch::ones({1}, torch::requires_grad());
Variable z = x * 2;
Variable w = y * 2;
auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
ASSERT_FALSE(grad_res[1].defined());
// This is slightly different than the case above, because z doesn't even
// have a grad accumulator allocated.
z = torch::ones({1}, torch::requires_grad());
grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);
ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
ASSERT_FALSE(grad_res[1].defined());
// With allow_unused=false, an unreachable input (whose gradient would be
// undefined) should throw.
ASSERT_THROWS_WITH(
grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}
// Regression test for #51940 (see also #39784): passing `inputs` whose
// grad_fn has not been materialized must not cause unrelated nodes to be
// marked as needed.
TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
// Test that certain nodes are not erroneously executed when an input
// is unreachable. See #39784
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
return var;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
ADD_FAILURE() << "This node should not be executed!";
return grad_output;
}
};
auto x = torch::randn(1, torch::requires_grad());
auto x1 = torch::randn(1);
auto x2 = MyFunction::apply(x + x1);
auto y = torch::randn(1, torch::requires_grad());
auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
ASSERT_FALSE(grad_res[0].defined());
}
TEST(AutogradAPITests, EmptyInput) {
Variable x = torch::ones({1}, torch::requires_grad());
ASSERT_THROWS_WITH(
grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs.");
}
TEST(AutogradAPITests, RetainGrad) {
auto input = torch::rand({1, 3}, torch::requires_grad());
auto h1 = input * 3;
auto out = (h1 * h1).sum();
{
// Warning when grad is accessed for non-leaf tensor
WarningCapture warnings;
ASSERT_FALSE(h1.grad().defined());
ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos);
}
// It should be possible to call retain_grad() multiple times
h1.retain_grad();
h1.retain_grad();
{
// If retain_grad is true for a non-leaf tensor,
// there should not be any warning when grad is accessed
WarningCapture warnings;
ASSERT_FALSE(h1.grad().defined());
ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos);
}
// Gradient should be accumulated
  out.backward({}, /*retain_graph=*/true);
ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  out.backward({}, /*retain_graph=*/true);
ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());
{
torch::NoGradGuard no_grad;
input.grad().zero_();
}
// It should be a no-op for leaves
input.retain_grad();
input.retain_grad();
out.backward();
ASSERT_VARIABLE_EQ(input * 18, input.grad());
}
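// DetectAnomalyGuard is an RAII guard: anomaly detection (and its check_nan
// flag) reverts to the previous setting when a guard goes out of scope, which
// is what the nested scopes below exercise.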
TEST(AutogradAPITests, AnomalyMode) {
// Anomaly mode should emit the forward-pass traceback as a warning and then
// raise an error.
torch::autograd::DetectAnomalyGuard detect_anomaly;
{
WarningCapture warnings;
auto x = torch::tensor({5.0}, torch::requires_grad());
auto y = x * x;
auto z = y * y;
y += 1;
ASSERT_THROWS_WITH(z.backward(), "inplace");
ASSERT_TRUE(
warnings.str().find("Traceback of forward") != std::string::npos);
}
auto double_backward_produce_nan = [](bool should_throw) {
auto x = torch::tensor({0.0}, torch::requires_grad());
auto y = x.pow(1.5);
    auto gr =
        grad({y}, {x}, {}, /*retain_graph=*/true, /*create_graph=*/true);
if (should_throw) {
WarningCapture warnings;
      ASSERT_THROWS_WITH(
          grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
auto msgs = warnings.messages();
ASSERT_EQ(msgs.size(), 2);
ASSERT_TRUE(
msgs[0].find("Traceback of forward call that caused the error") !=
std::string::npos);
ASSERT_TRUE(
msgs[1].find(
"Traceback of forward call that induced the previous calculation") !=
std::string::npos);
} else {
grad({gr[0]}, {x}, {torch::tensor({0.0})});
}
};
double_backward_produce_nan(true);
{
torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false);
double_backward_produce_nan(false);
{
torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true);
double_backward_produce_nan(true);
}
}
double_backward_produce_nan(true);
}
TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(
AutogradContext* ctx,
Variable var1,
Variable var2) {
ctx->save_for_backward({var1, var2});
      // Note: the comma operator evaluates and discards var1 * var2, then
      // returns var1 itself, i.e. the input is returned as-is after saving.
      return var1 * var2, var1;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {};
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
MyFunction::apply(x, y);
}
TEST(CustomAutogradTest, CustomFunction) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(
AutogradContext* ctx,
Variable var1,
int mul,
Variable var2) {
ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
return var1 + mul * var2 + var1 * var2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
variable_list output = {
grad_output[0] + grad_output[0] * var2,
Variable(),
grad_output[0] * mul + grad_output[0] * var1};
return output;
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto res = MyFunction::apply(x, 2, y);
auto go = torch::ones({}, torch::requires_grad());
res.sum().backward(go, false, true);
ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}
TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
torch::autograd::variable_list vars;
for (const at::Tensor& tensor : tensors) {
vars.push_back(tensor);
}
ctx->save_for_backward(vars);
return tensors[0] + tensors[1] + tensors[0] * tensors[1];
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
variable_list output = {
grad_output[0] + grad_output[0] * var2,
grad_output[0] + grad_output[0] * var1};
return output;
}
};
at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
torch::autograd::variable_list variables = {x, y};
at::TensorList tensors = variables;
auto res = MyFunction::apply(tensors);
auto go = torch::ones({}, torch::requires_grad());
res.sum().backward(go, false, true);
ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
}
// GraphTaskTrimEdges covers the optimization from #82544: during
// torch::autograd::grad() the graph task's exec_info is made visible to each
// node, so ctx->needs_input_grad() reflects which inputs actually need
// gradients and unnecessary edges can be skipped.
TEST(CustomAutogradTest, GraphTaskTrimEdges) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(
AutogradContext* ctx,
Variable var1,
Variable var2,
int mul,
bool needs_input1_grad,
bool needs_input2_grad) {
      // Record which inputs are expected to require gradients in backward().
ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
ctx->saved_data["needs_input2_grad"] = needs_input2_grad;
ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
return var1 + mul * var2 + var1 * var2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
      // Check that needs_input_grad() reports the expected values; this can
      // only be verified from within the backward function.
auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
IndexRange var1_idx = {0, 1};
IndexRange var2_idx = {1, 2};
EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
EXPECT_EQ(
ctx->needs_input_grad({var1_idx, var2_idx}),
needs_input1_grad || needs_input2_grad);
// calculate gradients
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
Variable grad_var1, grad_var2;
if (ctx->needs_input_grad(0)) {
grad_var1 = grad_output[0] + grad_output[0] * var2;
}
if (ctx->needs_input_grad(1)) {
grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
}
variable_list output = {
grad_var1,
grad_var2,
Variable(),
Variable(),
Variable(),
};
return output;
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto go = torch::ones_like(x);
Variable out;
// grad_x
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ false);
auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));
// grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ false,
/* needs_input2_grad= */ true);
auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);
// grad_x and grad_y
out = MyFunction::apply(
x,
y,
2,
/* needs_input1_grad= */ true,
/* needs_input2_grad= */ true);
auto grads = torch::autograd::grad({out}, {x, y}, {go});
ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
}
TEST(CustomAutogradTest, FunctionReturnsInput) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var1) {
return var1;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {grad_output[0] * 2};
}
};
Variable x(torch::ones(1, torch::requires_grad()));
MyFunction::apply(x).backward(torch::ones(1), true, true);
ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
}
TEST(CustomAutogradTest, FunctionReturnsUndefined) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
return var * 2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
at::Tensor undefined_tensor;
return {undefined_tensor};
}
};
auto x = torch::ones(1, torch::requires_grad());
MyFunction::apply(x).backward();
ASSERT_FALSE(x.grad().defined());
MyFunction::apply(x.pow(2)).backward();
ASSERT_FALSE(x.grad().defined());
MyFunction::apply(x).sum().backward();
ASSERT_FALSE(x.grad().defined());
ASSERT_FALSE(torch::autograd::grad(
{MyFunction::apply(x)}, {x}, {}, false, false, true)[0]
.defined());
}
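// By default the engine materializes undefined incoming gradients as
// zero-filled tensors before calling backward(); the next two tests check the
// default behavior and the opt-out via ctx->set_materialize_grads(false).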
TEST(CustomAutogradTest, MaterializeGrads) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
return var;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
return grad_output;
}
};
auto x = torch::ones(1, torch::requires_grad());
UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}
TEST(CustomAutogradTest, DontMaterializeGrads) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable var) {
ctx->set_materialize_grads(false);
return var;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
EXPECT_FALSE(grad_output[0].defined());
return grad_output;
}
};
auto x = torch::ones(1, torch::requires_grad());
UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
}
TEST(CustomAutogradTest, NoGradCustomFunction) {
// Custom Function should respect grad mode
struct MyOp : public Function<MyOp> {
static Variable forward(AutogradContext* ctx, Variable x) {
return x + 1;
}
static variable_list backward(AutogradContext* ctx, variable_list dy) {
return dy;
}
};
auto x = torch::ones({5, 5}, torch::requires_grad());
{
at::NoGradGuard no_grad;
auto y = MyOp::apply(x);
ASSERT_FALSE(y.requires_grad());
}
}
TEST(CustomAutogradTest, MarkDirty) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable v) {
// Change the value in place
auto v_data = v.data_ptr<float>();
v_data[0] = 2;
ctx->mark_dirty({v});
return v;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {(grad_output[0] * 2.0)};
}
};
// Clone here because modifying leaf tensors in place is not allowed
auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
auto version_before = x._version();
auto out = MyFunction::apply(x);
auto version_after = x._version();
ASSERT_TRUE(version_after >= (version_before + 1));
out.sum().backward();
}
TEST(CustomAutogradTest, MarkNonDifferentiable) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable v) {
Variable output = v > 0;
ctx->mark_non_differentiable({output});
return output;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {(grad_output[0] * 0.0)};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto mask = MyFunction::apply(x);
ASSERT_FALSE(mask.requires_grad());
auto y = x.masked_fill(mask, 0);
y.sum().backward();
}
TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
struct MyFunction : public Function<MyFunction> {
static variable_list forward(AutogradContext* ctx, Variable input) {
Variable a = input + 1;
Variable b = input + 2;
ctx->mark_non_differentiable({a});
return {a, b};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
return {grad_b};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto out = MyFunction::apply(x);
ASSERT_FALSE(out[0].requires_grad());
ASSERT_TRUE(out[1].requires_grad());
out[1].sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable input) {
auto output = input.clone();
ctx->mark_non_differentiable({output});
return output;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_outputs) {
return {};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto r = MyFunction::apply(x * x);
(r * x).sum().backward();
}
TEST(CustomAutogradTest, ReturnLeafInplace) {
struct Inplace : public Function<Inplace> {
static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
ctx->mark_dirty({a});
return {a.add_(b), b + 2};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {grad_output[0], grad_output[0] + grad_output[1]};
}
};
Variable x = torch::randn({5, 5});
Variable y = torch::randn({5, 5}, torch::requires_grad());
auto out = Inplace::apply(x, y);
auto& q = out[0];
ASSERT_TRUE(torch::equal(q, x));
ASSERT_TRUE(q.requires_grad());
q.sum().backward();
ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, ReturnDuplicateInplace) {
struct DoubleInplace : public Function<DoubleInplace> {
static variable_list forward(AutogradContext* ctx, Variable x) {
x.mul_(2);
ctx->mark_dirty({x});
return {x, x};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_outputs) {
return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
ASSERT_THROWS_WITH(
DoubleInplace::apply(x), "leaf Variable that requires grad");
// TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
// output");
auto out = DoubleInplace::apply(x.clone());
ASSERT_TRUE(torch::equal(out[0], out[1]));
}
TEST(CustomAutogradTest, ReturnDuplicate) {
struct DoubleDuplicate : public Function<DoubleDuplicate> {
static variable_list forward(AutogradContext* ctx, Variable x) {
auto output = x * 2;
return {output, output};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_outputs) {
return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
}
};
auto x = torch::randn({5, 5}, torch::requires_grad());
auto out = DoubleDuplicate::apply(x);
ASSERT_TRUE(torch::equal(out[0], out[1]));
}
TEST(CustomAutogradTest, SaveEmptyForBackward) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable input) {
ctx->save_for_backward({Variable(), input, Variable()});
return input * input;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
auto saved = ctx->get_saved_variables();
EXPECT_FALSE(saved[0].defined());
EXPECT_FALSE(saved[2].defined());
return {saved[1] * 2 * grad_output[0]};
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
auto y = MyFunction::apply(x);
y.sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}
TEST(CustomAutogradTest, InvalidGradients) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext* ctx, Variable x) {
return x * 2;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_outputs) {
return {
torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
}
};
auto input1 =
torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
ASSERT_THROWS_WITH(
MyFunction::apply(input1).sum().backward(), "expected shape");
auto input2 =
torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}
TEST(CustomAutogradTest, NoGradInput) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable x) {
return x;
}
static variable_list backward(
AutogradContext*,
variable_list grad_outputs) {
return grad_outputs;
}
};
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y;
{
at::NoGradGuard no_grad;
y = MyFunction::apply(x);
}
ASSERT_TRUE(x.requires_grad());
ASSERT_FALSE(y.grad_fn());
}
TEST(CustomAutogradTest, TooManyGrads) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable input) {
return input;
}
static variable_list backward(AutogradContext*, variable_list grad_output) {
grad_output.insert(grad_output.end(), {Variable(), Variable()});
return grad_output;
}
};
}
TEST(CustomAutogradTest, DepNoGrad) {
struct F1 : public Function<F1> {
static variable_list forward(AutogradContext* ctx, Variable input) {
auto out = torch::randn(input.sizes());
ctx->mark_non_differentiable({out});
return {input, out};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
return {grad_output[0]};
}
};
struct F2 : public Function<F2> {
static Variable forward(AutogradContext*, Variable input, Variable ignore) {
return input;
}
static variable_list backward(AutogradContext*, variable_list grad_output) {
return {grad_output[0], Variable()};
}
};
auto x = torch::randn(5, torch::requires_grad());
auto out = F1::apply(x);
Variable &a = out[0], &b = out[1];
b = b + 1; // Separate F1 and F2 by another operation
ASSERT_TRUE(a.requires_grad());
ASSERT_FALSE(b.requires_grad());
auto c = F2::apply(a, b);
c.backward(torch::ones(c.sizes()), false, false);
ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}
TEST(CustomAutogradTest, Reentrant) {
static Variable y_data = torch::randn({2, 2});
struct Reenter : public Function<Reenter> {
static Variable forward(AutogradContext* ctx, Variable input) {
Variable output;
{
at::AutoGradMode enable_grad(true);
auto x = make_variable(input.tensor_data(), true);
auto y = make_variable(y_data.tensor_data(), true);
output = x * y;
ctx->saved_data["x"] = x;
ctx->saved_data["y"] = y;
ctx->saved_data["output_var"] = output;
}
return output.detach();
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
{
at::AutoGradMode enable_grad(true);
auto out = ctx->saved_data["output_var"].toTensor();
out.sum().backward();
}
return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
}
};
auto x = torch::randn({2, 2}, torch::requires_grad());
auto out = Reenter::apply(x);
out.sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), y_data);
}
// NOTE: If this test fails in TSAN for apparently unrelated reasons, be aware
// of the TSAN limit on mutexes: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
struct DeepReenter : public Function<DeepReenter> {
static Variable forward(AutogradContext* ctx, Variable x) {
{
at::AutoGradMode enable_grad(true);
ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
}
return ctx->saved_data["x"].toTensor().detach();
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
return grad_output;
}
{
at::AutoGradMode enable_grad(true);
apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
return grad_output;
}
}
};
// This should not stack overflow
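// even with 8193 nested re-entries: if each reentrant backward recursed on
// the calling thread's stack this would overflow, so the engine is expected
// to offload deep reentrant calls instead.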
auto v =
torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
DeepReenter::apply(v).sum().backward();
}
TEST(CustomAutogradTest, ReentrantPriority) {
static std::vector<int> order;
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable x) {
return x;
}
static variable_list backward(AutogradContext*, variable_list grad) {
order.push_back(0);
return grad;
}
};
struct Reenter : public Function<Reenter> {
static Variable forward(AutogradContext* ctx, Variable x) {
{
at::AutoGradMode enable_grad(true);
ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
}
return ctx->saved_data["x"].toTensor().detach();
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
order.push_back(1);
if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
return grad_output;
}
{
at::AutoGradMode enable_grad(true);
apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
return grad_output;
}
}
};
auto a = MyFunction::apply(
torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
auto b = Reenter::apply(
torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
auto v = a * b;
v.backward();
// All the reentrant tasks should be prioritized over the MyFunction backward
// task.
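// Reenter is seeded with 9, so its backward runs once for each saved value
// from 8 down to 0 (nine pushes of 1); MyFunction's backward then pushes the
// single trailing 0.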
ASSERT_EQ(order.size(), 10);
ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
ASSERT_EQ(order.back(), 0);
// Clear the static variable in case the test gets executed in a loop
order.clear();
}
TEST(CustomAutogradTest, Hooks) {
Variable x = torch::ones({5, 5}, torch::requires_grad());
Variable y = torch::ones({5, 5}) * 4;
y.set_requires_grad(true);
int counter = 0;
std::function<void(int, Variable)> bw_hook(
[&counter](int inc, Variable grad) { counter += inc; });
Variable z = x * x + x * 2 + x * y + y;
x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
auto hook_1 =
z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
z.backward(torch::ones({5, 5}), true, true);
ASSERT_EQ(counter, 1);
auto hook_2 =
z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
z.backward(torch::ones({5, 5}), true, true);
ASSERT_EQ(counter, 4);
z.remove_hook(hook_2);
z.backward(torch::ones({5, 5}), true, true);
ASSERT_EQ(counter, 5);
std::function<Variable(Variable)> bw_hook_modify(
[](Variable grad) { return grad.mul(2); });
z.remove_hook(hook_1);
z.register_hook(bw_hook_modify);
y.grad().zero_();
z.backward(torch::ones({5, 5}), true, false);
ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);
y.register_hook(bw_hook_modify);
y.grad().zero_();
z.backward(torch::ones({5, 5}), false, false);
ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);
ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}
TEST(CustomAutogradTest, HooksInplace) {
auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
int hook1_count = 0;
auto hook1 = ([&hook1_count](Variable grad) {
hook1_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
});
int hook2_count = 0;
auto hook2 = ([&hook2_count](Variable grad) {
hook2_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
});
a.register_hook(hook1);
a.mul_(2);
a.register_hook(hook2);
auto out = (a + 1).sum();
out.backward();
ASSERT_EQ(hook1_count, 1);
ASSERT_EQ(hook2_count, 1);
}
TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
int hook1_count = 0;
auto hook1 = ([&hook1_count](Variable grad) {
hook1_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
});
int hook2_count = 0;
auto hook2 = ([&hook2_count](Variable grad) {
hook2_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
});
int hook3_count = 0;
auto hook3 = ([&hook3_count](Variable grad) {
hook3_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
});
a.register_hook(hook1);
a.retain_grad();
a.register_hook(hook2);
a.mul_(2);
a.register_hook(hook3);
auto out = (a + 1).sum();
out.backward();
ASSERT_EQ(hook1_count, 1);
ASSERT_EQ(hook2_count, 1);
ASSERT_EQ(hook3_count, 1);
ASSERT_TRUE(a.retains_grad());
ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
int hook1_count = 0;
auto hook1 = ([&hook1_count](Variable grad) {
hook1_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
});
int hook2_count = 0;
auto hook2 = ([&hook2_count](Variable grad) {
hook2_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
});
int hook3_count = 0;
auto hook3 = ([&hook3_count](Variable grad) {
hook3_count++;
ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
});
a.register_hook(hook1);
a.retain_grad();
a.register_hook(hook2);
a.mul_(2);
a.mul_(2);
a.register_hook(hook3);
auto out = (a + 1).sum();
out.backward();
ASSERT_EQ(hook1_count, 1);
ASSERT_EQ(hook2_count, 1);
ASSERT_EQ(hook3_count, 1);
ASSERT_TRUE(a.retains_grad());
ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
TEST(CustomAutogradTest, HookNone) {
struct NoneGradientFunction : public Function<NoneGradientFunction> {
static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
return {x, y};
}
static variable_list backward(AutogradContext* ctx, variable_list grad) {
return {grad[0], Variable()};
}
};
bool was_called = false;
auto hook = ([&was_called](Variable grad) {
ASSERT_TRUE(grad.defined());
was_called = true;
});
auto x = torch::randn({5, 5}, torch::requires_grad());
auto y = torch::randn({5, 5});
auto out = NoneGradientFunction::apply(x, y);
Variable rx = x[0], ry = x[1];
rx.register_hook(hook);
ry.register_hook(hook);
(rx + ry).sum().backward();
ASSERT_TRUE(was_called);
}
TEST(CustomAutogradTest, BackwardWithInputs) {
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
Variable z = x * x + x * y + y * y;
Variable x_grad_expected = 2 * x + y;
Variable y_grad_expected = x + 2 * y;
z.backward(torch::ones({5, 5}), false, false, {x});
ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
ASSERT_FALSE(y.grad().defined());
}
TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
Variable z = x * x + x * y + y * y;
Variable x_grad_expected = 2 * x + y;
Variable y_grad_expected = x + 2 * y;
ASSERT_THROWS_WITH(
z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
"cannot be empty");
}
TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
Variable x = torch::randn({5, 5}, torch::requires_grad());
Variable y = torch::randn({5, 5}, torch::requires_grad());
Variable z = x * x;
Variable w = y * z + x * y + y * y;
Variable x_grad_expected = 2 * x * y + y;
Variable z_grad_expected = y;
w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});
ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
ASSERT_FALSE(y.grad().defined());
}
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
c10::WarningUtils::WarnAlways guard(true);
torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
auto z = x * x;
{
WarningCapture warnings;
z.backward(torch::ones({5, 5}), std::nullopt, true);
ASSERT_TRUE(
warnings.str().find("Using backward() with create_graph=True") !=
std::string::npos);
}
{
WarningCapture warnings;
torch::autograd::backward({z}, {torch::ones({5, 5})}, std::nullopt, true);
ASSERT_TRUE(
warnings.str().find("Using backward() with create_graph=True") !=
std::string::npos);
}
}
/**
* Tests for AutogradNotImplementedFallback
* - Check that the NotImplemented node is created when inputs require grad,
*   and that no such node is created when no inputs require grad
* - check_inplace logic
* - view ops
* - TODO: Tests for debug-only checks? Don't need for now because CI doesn't
* test non-NDEBUG builds.
* - tensorlist input and output
* - multiple outputs / non-tensor output
* - rebase_history vs set_history
*/
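// The helper operators below only have CPU kernels; the autograd behavior
// exercised by these tests comes from the two fallbacks that REGISTER_TEST_OP
// (defined after the helpers) installs for the Autograd and ADInplaceOrView
// dispatch keys.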
namespace {
torch::Tensor inplace_op(
const torch::Tensor& self,
const torch::Tensor& other) {
return self.add_(other);
}
std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
const torch::Tensor& self,
const torch::Tensor& other) {
other.add_(self);
self.add_(other);
return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}
std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
const torch::Tensor& self,
const torch::Tensor& other) {
// This is not allowed. We test below that calling this op through the boxed
// kernel raises an error
return std::tuple<torch::Tensor, torch::Tensor>(self, other);
}
std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
const torch::Tensor& self,
const torch::Tensor& other) {
// This is not allowed. We test below that calling this op through the boxed
// kernel raises an error
return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other);
}
int64_t ret_single_non_tensor(
const torch::Tensor& self,
const torch::Tensor& other) {
return 12;
}
torch::Tensor opt_op(
const torch::Tensor& self,
const std::optional<at::Tensor>& other) {
if (other.has_value()) {
return self + other.value();
} else {
return self.clone();
}
}
torch::Tensor my_custom_op(
const torch::Tensor& self,
const torch::Tensor& other) {
return self + other;
}
std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
const torch::Tensor& self,
const torch::Tensor& other) {
auto a = self - other;
auto b = self + other;
return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
}
torch::Tensor view_op(const torch::Tensor& self) {
return self.alias();
}
torch::Tensor view_op_with_extra_arg(
const torch::Tensor& self,
const torch::Tensor& other) {
return self.alias();
}
std::vector<torch::Tensor> ret_tensor_vector_view(
const torch::Tensor& self,
const torch::Tensor& other) {
return {self.alias(), self.alias()};
}
std::vector<at::Tensor> ret_tensor_vector(
const torch::Tensor& self,
const torch::Tensor& other) {
std::vector<at::Tensor> out;
out.push_back(self + other);
out.push_back(self - other);
return out;
}
torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
const auto& res = self.clone();
for (const auto& t : other) {
res.add_(t);
}
return res;
}
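// Registers `fn` as the CPU kernel of a fresh `_test::<name>` op and routes
// its Autograd and ADInplaceOrView keys through the not-implemented
// fallbacks. The library handles are local variables, so the macro is meant
// to be expanded once per test body; the registrations are dropped when the
// handles go out of scope at the end of the test.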
#define REGISTER_TEST_OP(name, schema, fn) \
auto m = MAKE_TORCH_LIBRARY(_test); \
m.def(schema); \
auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \
auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \
auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView); \
m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \
m_autograd.impl( \
name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \
m_inplaceorview.impl( \
name, \
c10::DispatchKey::ADInplaceOrView, \
autogradNotImplementedInplaceOrViewFallback());
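// Shared checks for ops registered through the not-implemented fallback:
// backward raises "is not implemented" when an input requires grad, and the
// output has no grad_fn when no input requires grad.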
template <typename F>
void assertBasicChecks(F op) {
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
auto c = torch::tensor({1.}, {torch::kFloat32});
// If any input requires grad, backward through the output should raise the
// not-implemented error
auto out1 = op(a, b);
ASSERT_THROWS_WITH(out1.backward(), "is not implemented");
// The output should not have a grad_fn if no inputs require grad
auto out2 = op(b, c);
ASSERT_THROWS_WITH(
out2.backward(),
"element 0 of tensors does not require grad and does not have a grad_fn");
// TODO: Forward AD Tests?
}
} // namespace
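// Illustrative sketch (not part of the original test suite): per the PR
// summary above, an out-of-tree extension can register the same two boxed
// fallbacks directly instead of going through REGISTER_TEST_OP or the
// REGISTER_AUTOGRAD_NOT_IMPLEMENTED_FALLBACK macro. The `_sketch` namespace
// and the op/kernel below are hypothetical.
static torch::Tensor sketch_relu_cpu(const torch::Tensor& self) {
  // Plain CPU kernel with no derivative formula.
  return self.clamp_min(0);
}
TORCH_LIBRARY(_sketch, m) {
  m.def("sketch_relu(Tensor self) -> Tensor");
}
TORCH_LIBRARY_IMPL(_sketch, CPU, m) {
  m.impl("sketch_relu", TORCH_FN(sketch_relu_cpu));
}
TORCH_LIBRARY_IMPL(_sketch, Autograd, m) {
  // Makes backward raise the not-implemented error when an input requires grad.
  m.impl("sketch_relu", autogradNotImplementedFallback());
}
TORCH_LIBRARY_IMPL(_sketch, ADInplaceOrView, m) {
  // Handles version-counter and view bookkeeping.
  m.impl("sketch_relu", autogradNotImplementedInplaceOrViewFallback());
}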
TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
REGISTER_TEST_OP(
"ret_single_non_tensor",
"_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
ret_single_non_tensor);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_single_non_tensor", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
}
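// In-place op through the fallback: mutating a leaf that requires grad
// throws, the result aliases the input, writing to a view created in
// no_grad mode is rejected, and in-place on a differentiable view rebases
// onto an AsStridedBackward node.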
TEST(TestAutogradNotImplementedFallback, InplaceOp) {
REGISTER_TEST_OP(
"inplace_op",
"_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
inplace_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
// Check in-place
ASSERT_THROWS_WITH(
op(a, b),
"a leaf Variable that requires grad is being used in an in-place operation");
op(b, a);
a = a.clone();
b = b.clone();
auto c = op(a, b);
ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));
// Test in-place on view
auto base =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto view = base.view(-1);
auto t = torch::tensor({1.}, {torch::kFloat32});
torch::Tensor v_nograd;
{
c10::NoGradGuard guard;
v_nograd = base.view(-1);
op(v_nograd, t);
}
ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
ASSERT_THAT(
op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
}
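// Op that writes to both of its arguments: the leaf-variable check applies
// to either argument, and a successful call bumps both version counters.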
TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
REGISTER_TEST_OP(
"two_arg_inplace_op",
"_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
two_arg_inplace_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::two_arg_inplace_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
// Both are modified in-place!
ASSERT_THROWS_WITH(
op(a, b),
"a leaf Variable that requires grad is being used in an in-place operation");
ASSERT_THROWS_WITH(
op(b, a),
"a leaf Variable that requires grad is being used in an in-place operation");
auto c =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto d =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto saved_version_c = c._version();
auto saved_version_d = d._version();
op(c, d);
ASSERT_NE(c._version(), saved_version_c);
ASSERT_NE(d._version(), saved_version_d);
}
TEST(TestAutogradNotImplementedFallback, OptOp) {
REGISTER_TEST_OP(
"opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
auto op = [&](const torch::Tensor& _1,
const std::optional<torch::Tensor>& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const std::optional<torch::Tensor>&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
}
TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
REGISTER_TEST_OP(
"my_custom_op",
"_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
my_custom_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
assertBasicChecks(op);
}
TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
REGISTER_TEST_OP(
"ret_tuple_non_tensor",
"_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
ret_tuple_non_tensor);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_tuple_non_tensor", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
auto out = callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor, int64_t>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
auto [out0, out1, out2] = std::move(out);
return out0;
};
assertBasicChecks(op);
}
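// View op through the fallback: the output reports is_view() with the input
// as its base; mutating the view, or querying its grad_fn after the base
// was mutated, raises the not-implemented error.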
TEST(TestAutogradNotImplementedFallback, ViewOp) {
REGISTER_TEST_OP(
"view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
auto opHandle =
c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
auto op = [&](const torch::Tensor& _1) {
return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
};
auto b = torch::tensor({1.}, {torch::kFloat32});
auto v = op(b);
ASSERT_TRUE(v.is_view());
ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
auto b1 =
torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
auto v1 = op(b1);
ASSERT_TRUE(v1.is_view());
ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());
// Test inplace on view
auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
// rebase_history raises when it tries to refresh the view's grad_fn
ASSERT_THROWS_WITH(
v1.add_(t), "which does not have a derivative implemented is forbidden");
// base should not be aware of the views, so this is still okay
b1.add_(t);
ASSERT_THROWS_WITH(
v1.grad_fn(),
"which does not have a derivative implemented is forbidden");
}
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
REGISTER_TEST_OP(
"view_op_with_extra_arg",
"_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
view_op_with_extra_arg);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::view_op_with_extra_arg", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
torch::Tensor,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
assertBasicChecks(op);
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({2.}, {torch::kFloat32});
auto out1 = op(a, b);
ASSERT_TRUE(out1.is_view());
ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}
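// The view relationship is also created when the op returns a vector of
// tensors: every output is a view whose base is the first input.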
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
REGISTER_TEST_OP(
"ret_tensor_vector_view",
"_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
ret_tensor_vector_view);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_tensor_vector_view", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::vector<at::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({1.}, {torch::kFloat32});
auto out = op(a, b);
ASSERT_TRUE(out[0].is_view());
ASSERT_EQ(out[0]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
ASSERT_TRUE(out[1].is_view());
ASSERT_EQ(out[1]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
}
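// Schemas with more than one non-write alias annotation are rejected.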
TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
REGISTER_TEST_OP(
"two_pairs_of_view_op",
"_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
two_pairs_of_view_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::two_pairs_of_view_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_THROWS_WITH(
op(a, b),
"Expected only a single output in the operator schema to have a non-write alias annotation");
}
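// The view relationship must be between the first input and the first
// output; anything else is rejected.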
TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
REGISTER_TEST_OP(
"non_first_view_op",
"_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
non_first_view_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::non_first_view_op", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::tuple<torch::Tensor, torch::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
auto b = torch::tensor({1.}, {torch::kFloat32});
ASSERT_THROWS_WITH(
op(a, b), "can only create view relationships between the first");
}
TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
REGISTER_TEST_OP(
"ret_tensor_vector",
"_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
ret_tensor_vector);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::ret_tensor_vector", "");
auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
return callOpUnboxed<
std::vector<at::Tensor>,
const torch::Tensor&,
const torch::Tensor&>(opHandle, _1, _2)[0];
};
assertBasicChecks(op);
}
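// TensorList argument: grad w.r.t. an entry that does not require grad
// fails with the usual autograd error, while an entry that does require
// grad hits the not-implemented error.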
TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
REGISTER_TEST_OP(
"tensorlist_op",
"_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
tensorlist_op);
auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
"_test::tensorlist_op", "");
auto op = [&](torch::Tensor _1, at::TensorList _2) {
return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
opHandle, _1, _2);
};
auto a = torch::tensor({1.}, {torch::kFloat32});
auto b = torch::tensor({1.}, {torch::kFloat32});
auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
std::vector<torch::Tensor> vec = {b, c};
auto out = op(a, vec);
ASSERT_THROWS_WITH(
torch::autograd::grad({out}, {vec[0]}),
"element 0 of the input tensors does not require grad");
ASSERT_THROWS_WITH(
torch::autograd::grad({out}, {vec[1]}), "is not implemented");
ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
}
static std::string test_format_error(const std::string& s) {
return s;
}
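// validate_outputs should reduce a gradient to the shape recorded in the
// input metadata (summing to a scalar here) and leave an already-matching
// gradient untouched.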
TEST(TestAutogradUtils, ValidateOutputsReduce) {
auto input = torch::ones({}, {torch::kFloat32});
auto grad = torch::ones({2, 3}, {torch::kFloat32});
std::vector<std::optional<InputMetadata>> input_metadata;
input_metadata.emplace_back(InputMetadata(input));
std::vector<torch::Tensor> grads;
grads.emplace_back(grad);
torch::autograd::validate_outputs(input_metadata, grads, test_format_error);
ASSERT_TRUE(at::allclose(grads[0], grad.sum()));
}
TEST(TestAutogradUtils, ValidateOutputsBasic) {
auto input = torch::zeros({2, 3}, {torch::kFloat32});
auto grad = torch::ones({2, 3}, {torch::kFloat32});
std::vector<std::optional<InputMetadata>> input_metadata;
input_metadata.emplace_back(InputMetadata(input));
std::vector<torch::Tensor> grads;
grads.emplace_back(grad);
torch::autograd::validate_outputs(input_metadata, grads, test_format_error);
ASSERT_TRUE(at::allclose(grad, torch::ones({2, 3})));
}
// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd-function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy