mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Summary: In general, device_ is not very useful in OpKernel. Remove it to avoid misuse. Also, the meaning of `device_` is also ambiguous in the OpKernel. For StaticDispatch kernels, we always call cpu kernel. For C10Kernel, we rely on input tensor's device and dispatcher to determine which device to run on. For ops involves multiple device, e.g. aten._to_copy(device), the meaning of device is ill-defined. Test Plan: CI Rollback Plan: Reviewed By: henryoier, dolpm, kqfu, zhxchen17 Differential Revision: D78704840 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158944 Approved by: https://github.com/dolpm
75 lines
1.8 KiB
C++
#include <ATen/core/op_registration/op_registration.h>
|
|
#include <gtest/gtest.h>
|
|
#include <torch/nativert/executor/ExecutionFrame.h>
|
|
#include <torch/nativert/graph/Graph.h>
|
|
#include <torch/nativert/kernels/C10Kernel.h>
|
|
#include <torch/torch.h>
|
|
|
|
namespace torch::nativert {
|
|
|
|
// Elementwise-sum kernel used to back the custom "test::foo" operator
// registered by the test below; simply adds the two input tensors.
at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
  return at::add(a, b);
}
|
|
|
|
// Verifies that C10Kernel routes a node through the c10 dispatcher:
// a custom op is registered, bound to a graph node, executed against an
// ExecutionFrame, and the frame's output tensor must equal eager a + b.
TEST(C10KernelTest, computeInternal) {
  auto registrar = c10::RegisterOperators().op(
      "test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);

  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = test.foo.default(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);
  // The op node is the second entry; the first is the graph input node.
  const auto& nodeList = graph->nodes();
  const Node& node = *std::next(nodeList.begin());

  auto lhs = at::randn({6, 6, 6});
  auto rhs = at::randn({6, 6, 6});

  ExecutionFrame frame{*graph};
  frame.setIValue(graph->getValue("a")->id(), lhs);
  frame.setIValue(graph->getValue("b")->id(), rhs);

  C10Kernel kernel{&node};
  kernel.computeInternal(frame);

  // Output written into the frame must match an eager-mode elementwise add.
  at::Tensor expected = lhs + rhs;
  EXPECT_TRUE(
      torch::equal(frame.getTensor(graph->getValue("x")->id()), expected));
}
|
|
|
|
// Verifies that ScalarBinaryOpKernel evaluates _operator.add on plain
// scalar (non-tensor) inputs and stores the integer result in the frame.
TEST(ScalarBinaryOpKernelTest, computeInternal) {
  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = _operator.add(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);
  // The op node is the second entry; the first is the graph input node.
  const auto& nodeList = graph->nodes();
  const Node& node = *std::next(nodeList.begin());

  int lhs = 1;
  int rhs = 2;

  ExecutionFrame frame{*graph};
  frame.setIValue(graph->getValue("a")->id(), lhs);
  frame.setIValue(graph->getValue("b")->id(), rhs);

  ScalarBinaryOpKernel kernel{&node};
  kernel.computeInternal(frame);

  // 1 + 2: the kernel must produce the integer 3 in the output slot.
  EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), lhs + rhs);
}
|
|
|
|
} // namespace torch::nativert
|