// File: pytorch/test/cpp/jit/test_graph_executor.cpp
// (Extracted from a GitHub blame view that ignored revisions listed in
// .git-blame-ignore-revs; original file: 72 lines, 2.4 KiB, C++.)

#include <gtest/gtest.h>
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/runtime/graph_executor.h"
#include "torch/jit.h"
#include "torch/script.h"
#include "torch/torch.h"
namespace torch {
namespace jit {
// Runs the LSTM graph returned by build_lstm() through a GraphExecutor on
// CUDA tensors and checks both outputs against the reference lstm().
// build_lstm, lstm, t_def, createStack and almostEqual are test helpers
// (presumably declared in test/cpp/jit/test_utils.h — confirm).
TEST(GraphExecutorTest, Basic_CUDA) {
  constexpr int batch_size = 4;
  constexpr int input_size = 256;
  constexpr int hidden_size = 2 * input_size;

  auto input = at::randn({batch_size, input_size}, at::kCUDA);
  auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
  auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
  auto g = build_lstm();

  // NOTE(review): the second argument appears to be the function name used in
  // interpreter tracebacks (introduced by PR #33834); empty here — confirm.
  GraphExecutor executor(g, "");
  auto stack = createStack({input, hx, cx, w_ih, w_hh});
  executor.run(stack);

  // The executed graph must leave exactly two values on the stack; compare
  // each one against the reference implementation. `2u` keeps the comparison
  // unsigned to match stack.size() and avoid -Wsign-compare.
  ASSERT_EQ(stack.size(), 2u);
  auto [r0, r1] = lstm(input, hx, cx, w_ih, w_hh);
  ASSERT_TRUE(almostEqual(stack[0].toTensor(), r0));
  ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1));
}
// Verifies that GraphExecutor::runAsync dispatches forked work through the
// caller-supplied task launcher: the launcher counts its invocations and the
// test asserts it was used at least once.
TEST(GraphExecutorTest, runAsync_executor) {
  /*
  TODO: there are some problems with C++ parsing of script programs involving
  fork, so the test module below is loaded from disk for now.
  Issue about this: github.com/pytorch/pytorch/issues/46368
  The test module file is generated by the following:

  class DemoModule(torch.nn.Module):
      def forward(self):
          r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
          r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
          return r1.wait() + r2.wait()
  demo = DemoModule()
  torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
  */
  // The serialized module lives in the same directory as this source file.
  std::string filePath(__FILE__);
  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  testModelFile.append("test_interpreter_async.pt");
  auto module = load(testModelFile);
  auto graph = module.get_method("forward").graph();
  GraphExecutor graphExecutor(graph, "");

  int asyncCounter = 0;
  std::mutex mtx;
  // A launcher that counts each invocation, then hands the task to at::launch.
  // It captures by reference. That is intended to be safe because the future
  // returned by runAsync is waited on below, before these locals go out of scope.
  auto launcher = [&](std::function<void()> f) {
    {
      // RAII lock: released even if the increment path ever throws.
      std::lock_guard<std::mutex> guard(mtx);
      ++asyncCounter;
    }
    at::launch(std::move(f));
  };

  std::vector<IValue> stack;
  stack.emplace_back(module._ivalue());
  graphExecutor.runAsync(stack, launcher)->wait();

  // Read the counter under the same mutex that guards its writes.
  int launchCount = 0;
  {
    std::lock_guard<std::mutex> guard(mtx);
    launchCount = asyncCounter;
  }
  ASSERT_GT(launchCount, 0);
}
} // namespace jit
} // namespace torch