[static-runtime] add nnc codegen for aten::div (#76903)
Differential Revision: D36151087
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76903
Approved by: https://github.com/mikeiovine
Committed by: PyTorch MergeBot
Parent: 0b349f7e69
Commit: bf75708ce4
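Note (not part of the diff): the generated kernel has to match the semantics of aten::div with an optional rounding_mode. A minimal ATen sketch of the three behaviors, assuming the div overload that accepts an optional string rounding mode:

#include <ATen/ATen.h>

// Illustration only: the three division behaviors that the new kernel
// encodes as an integer mode (0 = default, 1 = trunc, 2 = floor).
void div_rounding_modes() {
  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto q_true  = at::div(a, b);                             // true division
  auto q_trunc = at::div(a, b, c10::string_view("trunc"));  // trunc after div
  auto q_floor = at::div(a, b, c10::string_view("floor"));  // floor after div
  (void)q_true; (void)q_trunc; (void)q_floor;
}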
@@ -4349,7 +4349,7 @@ TEST(StaticRuntime, autogen_take_along_dim) {
   )IR";

   auto self0 = at::rand({6, 6, 6});
-  auto indices0 = at::argsort(self0, 1);
+  auto indices0 = at::argsort(self0, 1, true);
   auto dim0 = 1;
   std::vector<IValue> args{self0, indices0, dim0};
   testStaticRuntime(
@@ -4361,7 +4361,7 @@ TEST(StaticRuntime, autogen_take_along_dim) {
   /*check_resize=*/true);

   auto self1 = at::rand({22, 22, 22});
-  auto indices1 = at::argsort(self1, 1);
+  auto indices1 = at::argsort(self1, 1, true);
   auto dim1 = 1;
   std::vector<IValue> args2{self1, indices1, dim1};
   testStaticRuntime(

@@ -877,15 +877,30 @@ TEST(StaticRuntime, Div) {
     return torch.div(a, b, rounding_mode=c).clone()
   )JIT";

+  const auto div_strided = R"JIT(
+    def forward(self, a: Tensor, b: Tensor):
+        a_strided = torch.transpose(a, 0, 1)
+        b_strided = torch.transpose(b, 0, 1)
+        return torch.div(a_strided, b_strided).clone()
+  )JIT";
+
   auto a = at::randn({2, 3});
   auto b = at::randn({2, 3});
+  auto bs = at::randn({3, 2}).transpose(0, 1);
   auto c = at::randn({4, 3, 2});
   auto d = at::randn({4, 3, 2});
+  auto ds = at::randn({3, 4, 2}).transpose(0, 1);

   std::vector<IValue> args0{a, b};
   testStaticRuntime(div_tensor, args0);
   testStaticRuntime(div_tensor, args0, {c, d});

+  testStaticRuntime(div_strided, args0);
+  testStaticRuntime(div_strided, args0, {c, d});
+
+  testStaticRuntime(div_tensor, {a, bs});
+  testStaticRuntime(div_tensor, {a, bs}, {c, ds});
+
   std::vector<IValue> args1{a, 3};
   testStaticRuntime(div_scalar, args1);
   testStaticRuntime(div_scalar, args1, {c, 4});
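
Note (not part of the diff): the strided cases exist because a transposed tensor is a non-contiguous view, so the NNC fast path added below (which requires contiguous float inputs with matching sizes and strides) is skipped and the fallback at::cpu::div_out runs instead. A small ATen sketch of that property:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Same construction as `bs` above: allocate {3, 2}, then view it as {2, 3}.
  auto bs = at::randn({3, 2}).transpose(0, 1);
  std::cout << bs.sizes() << " is_contiguous=" << bs.is_contiguous() << "\n";
  // The transposed view is not contiguous, so the kernel's contiguity check fails
  // and the operator falls back to at::cpu::div_out.
  return 0;
}
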
@@ -1850,7 +1850,8 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
     LogAndDumpSchema(n);
     return nullptr;
   }
-  return [](ProcessedNode* p_node) {
+
+  return [te = createDiv()](ProcessedNode* p_node) {
     const auto& in0_t = p_node->Input(0).toTensor();
     c10::optional<c10::string_view> rounding_mode = c10::nullopt;
     if (p_node->num_inputs() > 2) {
@@ -1861,12 +1862,37 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
         : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());

     if (p_node->Output(0).isNone()) {
-      p_node->Output(0) = at::cpu::div(in0_t, in1_t, rounding_mode);
-      return;
+      p_node->Output(0) = create_empty_from(in0_t);
     }
     auto& out_t = p_node->Output(0).toTensor();
-    fastResizeToZero(out_t);
-    at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
+
+    if (in0_t.sizes() == in1_t.sizes() &&
+        in0_t.scalar_type() == in1_t.scalar_type() &&
+        in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() &&
+        in0_t.scalar_type() == at::kFloat) {
+      int64_t dim = in0_t.numel();
+      int i_rounding_mode = 0;
+      if (rounding_mode && !rounding_mode.value().empty()) {
+        const char peek_rounding_mode = rounding_mode.value().at(0);
+        if (peek_rounding_mode == 't') {
+          // trunc after div
+          i_rounding_mode = 1;
+        } else if (peek_rounding_mode == 'f') {
+          // floor after div
+          i_rounding_mode = 2;
+        }
+      }
+      at::native::resize_(out_t, in0_t.sizes());
+      te->call(
+          {out_t.data_ptr(),
+           in0_t.data_ptr(),
+           in1_t.data_ptr(),
+           &i_rounding_mode,
+           &dim});
+    } else {
+      fastResizeToZero(out_t);
+      at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
+    }
   };
 });
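
Note (not part of the diff): the operator body collapses the optional rounding_mode string into a plain int flag before calling the generated kernel. A standalone sketch of that encoding; encodeRoundingMode is a name invented for this illustration:

#include <c10/util/Optional.h>
#include <c10/util/string_view.h>

// Hypothetical helper mirroring the branch in the operator body above.
int encodeRoundingMode(const c10::optional<c10::string_view>& rounding_mode) {
  if (!rounding_mode || rounding_mode->empty()) {
    return 0; // no rounding_mode: plain true division
  }
  const char peek = rounding_mode->at(0);
  if (peek == 't') {
    return 1; // "trunc": truncate after div
  }
  if (peek == 'f') {
    return 2; // "floor": floor after div
  }
  return 0;
}
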
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/runtime/static/impl.h>
+#include <torch/csrc/jit/tensorexpr/expr.h>
 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
 #include <torch/csrc/jit/tensorexpr/operators/operators.h>

@@ -113,6 +114,46 @@ void updateNNCCache(NodeKind kind, std::shared_ptr<TEWrapper> code) {

 } // namespace

+std::shared_ptr<TEWrapper> createDiv() {
+  auto wrap = lookupNNCCache(aten::div);
+  if (wrap) {
+    return wrap;
+  }
+  wrap = std::make_shared<TEWrapper>();
+
+  auto dim = VarHandle("dim", kInt);
+  auto mode = VarHandle("mode", kInt);
+  BufHandle A("A", {dim}, kFloat);
+  BufHandle B("B", {dim}, kFloat);
+
+  using axis = const VarHandle&;
+  Tensor C = Compute("C", {dim}, [&](axis x) {
+    auto true_div_result = A.load(x) / B.load(x);
+
+    auto mode_default = IntImm::make(0);
+    auto mode_trunc = IntImm::make(1);
+    auto mode_floor = IntImm::make(2);
+
+    // this is a glorified ternary choice operator train
+    return CompareSelect::make(
+        mode,
+        mode_default,
+        true_div_result,
+        CompareSelect::make(
+            mode,
+            mode_trunc,
+            trunc(true_div_result),
+            floor(true_div_result),
+            kEQ),
+        kEQ);
+  });
+
+  wrap = wrapTECompute(wrap, C, {A, B, mode, dim});
+
+  updateNNCCache(aten::div, wrap);
+  return wrap;
+}
+
 std::shared_ptr<TEWrapper> createLogit() {
   auto wrap = lookupNNCCache(aten::logit);
   if (wrap) {
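
Note (not part of the diff): the nested CompareSelect in createDiv is a two-level select on the mode flag. A scalar sketch of what the generated kernel computes per element; div_with_mode is a name invented for this illustration:

#include <cmath>

// Per-element result of the generated kernel, written as plain C++.
float div_with_mode(float a, float b, int mode) {
  const float true_div_result = a / b;
  if (mode == 0) {
    return true_div_result;             // default: true division
  }
  if (mode == 1) {
    return std::trunc(true_div_result); // "trunc": round toward zero after div
  }
  return std::floor(true_div_result);   // "floor": round down after div
}
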
@@ -33,6 +33,7 @@ class TEWrapper {
 #endif
 };

+std::shared_ptr<TEWrapper> createDiv();
 std::shared_ptr<TEWrapper> createLogit();
 std::shared_ptr<TEWrapper> createRelu();
 std::shared_ptr<TEWrapper> createTanh();

@@ -102,9 +102,9 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N
         return
     if op_name == "take_along_dim":
         if index == 0:
-            arg_map["indices"] = "at::argsort(self0, 1)"
+            arg_map["indices"] = "at::argsort(self0, 1, true)"
         else:
-            arg_map["indices"] = "at::argsort(self1, 1)"
+            arg_map["indices"] = "at::argsort(self1, 1, true)"
         return
     if op_name == "masked_select":
         if index == 0: