Files
pytorch/test/cpp/nativert/test_c10_kernel.cpp
Sherlock Huang fb067de550 [NativeRT] Remove device_ member from OpKernel base class (#158944)
Summary:
In general, `device_` is not very useful in OpKernel. Remove it to avoid misuse.

The meaning of `device_` in OpKernel is also ambiguous:
For StaticDispatch kernels, we always call the CPU kernel.
For C10Kernel, we rely on the input tensors' devices and the dispatcher to determine which device to run on.
For ops involving multiple devices, e.g. aten._to_copy(device), the meaning of device is ill-defined.

Test Plan:
CI

Rollback Plan:

Reviewed By: henryoier, dolpm, kqfu, zhxchen17

Differential Revision: D78704840

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158944
Approved by: https://github.com/dolpm
2025-07-24 09:21:37 +00:00

75 lines
1.8 KiB
C++

#include <cstdint>
#include <iterator>
#include <string_view>

#include <ATen/core/op_registration/op_registration.h>
#include <gtest/gtest.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>
#include <torch/torch.h>
namespace torch::nativert {

namespace {

// Kernel backing the custom `test::foo` operator registered in
// C10KernelTest: returns the element-wise sum `a + b`.
// Kept in an anonymous namespace so this test-only helper has internal
// linkage and cannot collide at link time with another `foo_kernel` symbol
// in torch::nativert.
at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}

} // namespace

// Checks that C10Kernel runs a graph node bound to a user-registered c10
// operator (`test::foo`) and writes its tensor result into the frame.
TEST(C10KernelTest, computeInternal) {
  // The registration is owned by `registrar` and therefore scoped to this
  // test body.
  auto registrar = c10::RegisterOperators().op(
      "test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);
  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = test.foo.default(a=%a, b=%b)
return (%x)
)";
  auto graph = stringToGraph(source);

  // The node under test is taken to be the second entry of nodes()
  // (presumably preceded by a graph-input node — confirm against Graph).
  // Guard the size first so a parser change fails the test instead of
  // dereferencing end().
  const auto& nodes = graph->nodes();
  ASSERT_GE(std::distance(nodes.begin(), nodes.end()), 2);
  const Node& node = *std::next(nodes.begin());

  auto a = at::randn({6, 6, 6});
  auto b = at::randn({6, 6, 6});
  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), a);
  frame.setIValue(graph->getValue("b")->id(), b);

  auto kernel = C10Kernel(&node);
  kernel.computeInternal(frame);

  const at::Tensor expected = a + b;
  EXPECT_TRUE(
      torch::equal(frame.getTensor(graph->getValue("x")->id()), expected));
}

// Checks that ScalarBinaryOpKernel evaluates `_operator.add` on two integer
// scalars and stores the integer result in the frame.
TEST(ScalarBinaryOpKernelTest, computeInternal) {
  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = _operator.add(a=%a, b=%b)
return (%x)
)";
  auto graph = stringToGraph(source);

  // Same node-selection assumption and guard as in C10KernelTest above.
  const auto& nodes = graph->nodes();
  ASSERT_GE(std::distance(nodes.begin(), nodes.end()), 2);
  const Node& node = *std::next(nodes.begin());

  // int64_t matches the type returned by IValue::toInt() below.
  const int64_t a = 1;
  const int64_t b = 2;
  auto frame = ExecutionFrame(*graph);
  frame.setIValue(graph->getValue("a")->id(), a);
  frame.setIValue(graph->getValue("b")->id(), b);

  auto kernel = ScalarBinaryOpKernel(&node);
  kernel.computeInternal(frame);

  const int64_t expected = a + b;
  EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), expected);
}

} // namespace torch::nativert