[static-runtime] add nnc codegen for aten::div (#76903)

Differential Revision: D36151087

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76903
Approved by: https://github.com/mikeiovine
Max Podkorytov
2022-06-22 05:47:44 +00:00
committed by PyTorch MergeBot
parent 0b349f7e69
commit bf75708ce4
6 changed files with 92 additions and 9 deletions

View File

@@ -4349,7 +4349,7 @@ TEST(StaticRuntime, autogen_take_along_dim) {
   )IR";
   auto self0 = at::rand({6, 6, 6});
-  auto indices0 = at::argsort(self0, 1);
+  auto indices0 = at::argsort(self0, 1, true);
   auto dim0 = 1;
   std::vector<IValue> args{self0, indices0, dim0};
   testStaticRuntime(
@@ -4361,7 +4361,7 @@ TEST(StaticRuntime, autogen_take_along_dim) {
       /*check_resize=*/true);
   auto self1 = at::rand({22, 22, 22});
-  auto indices1 = at::argsort(self1, 1);
+  auto indices1 = at::argsort(self1, 1, true);
   auto dim1 = 1;
   std::vector<IValue> args2{self1, indices1, dim1};
   testStaticRuntime(

View File

@@ -877,15 +877,30 @@ TEST(StaticRuntime, Div) {
         return torch.div(a, b, rounding_mode=c).clone()
   )JIT";
+  const auto div_strided = R"JIT(
+    def forward(self, a: Tensor, b: Tensor):
+        a_strided = torch.transpose(a, 0, 1)
+        b_strided = torch.transpose(b, 0, 1)
+        return torch.div(a_strided, b_strided).clone()
+  )JIT";
   auto a = at::randn({2, 3});
   auto b = at::randn({2, 3});
+  auto bs = at::randn({3, 2}).transpose(0, 1);
   auto c = at::randn({4, 3, 2});
   auto d = at::randn({4, 3, 2});
+  auto ds = at::randn({3, 4, 2}).transpose(0, 1);
   std::vector<IValue> args0{a, b};
   testStaticRuntime(div_tensor, args0);
   testStaticRuntime(div_tensor, args0, {c, d});
+  testStaticRuntime(div_strided, args0);
+  testStaticRuntime(div_strided, args0, {c, d});
+  testStaticRuntime(div_tensor, {a, bs});
+  testStaticRuntime(div_tensor, {a, bs}, {c, ds});
   std::vector<IValue> args1{a, 3};
   testStaticRuntime(div_scalar, args1);
   testStaticRuntime(div_scalar, args1, {c, 4});
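The strided variants matter because the NNC fast path added to the aten::div operator below only fires when both inputs are contiguous floats with matching sizes, dtypes, and strides. A minimal standalone sketch (ATen only; not part of the commit) of why the transposed inputs take the fallback path:

// Standalone sketch: a transposed view has swapped strides and is not
// contiguous, so inputs like `bs`/`ds` above fail the fast-path check and
// exercise the at::cpu::div_out fallback rather than the NNC kernel.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto a = at::randn({2, 3});
  auto bs = at::randn({3, 2}).transpose(0, 1); // shape {2, 3}, strides {1, 2}
  std::cout << a.is_contiguous() << ' ' << bs.is_contiguous() << '\n'; // 1 0
  std::cout << at::allclose(at::div(a, bs), a / bs) << '\n';           // 1
  return 0;
}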

View File

@@ -1850,7 +1850,8 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
     LogAndDumpSchema(n);
     return nullptr;
   }
-  return [](ProcessedNode* p_node) {
+  return [te = createDiv()](ProcessedNode* p_node) {
     const auto& in0_t = p_node->Input(0).toTensor();
     c10::optional<c10::string_view> rounding_mode = c10::nullopt;
     if (p_node->num_inputs() > 2) {
@@ -1861,12 +1862,37 @@ REGISTER_OPERATOR_FUNCTOR(aten::div, aten_div, [](Node* n) -> SROperator {
         : at::native::wrapped_scalar_tensor(p_node->Input(1).toScalar());
     if (p_node->Output(0).isNone()) {
-      p_node->Output(0) = at::cpu::div(in0_t, in1_t, rounding_mode);
-      return;
+      p_node->Output(0) = create_empty_from(in0_t);
     }
     auto& out_t = p_node->Output(0).toTensor();
-    fastResizeToZero(out_t);
-    at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
+    if (in0_t.sizes() == in1_t.sizes() &&
+        in0_t.scalar_type() == in1_t.scalar_type() &&
+        in0_t.strides() == in1_t.strides() && in0_t.is_contiguous() &&
+        in0_t.scalar_type() == at::kFloat) {
+      int64_t dim = in0_t.numel();
+      int i_rounding_mode = 0;
+      if (rounding_mode && !rounding_mode.value().empty()) {
+        const char peek_rounding_mode = rounding_mode.value().at(0);
+        if (peek_rounding_mode == 't') {
+          // trunc after div
+          i_rounding_mode = 1;
+        } else if (peek_rounding_mode == 'f') {
+          // floor after div
+          i_rounding_mode = 2;
+        }
+      }
+      at::native::resize_(out_t, in0_t.sizes());
+      te->call(
+          {out_t.data_ptr(),
+           in0_t.data_ptr(),
+           in1_t.data_ptr(),
+           &i_rounding_mode,
+           &dim});
+    } else {
+      fastResizeToZero(out_t);
+      at::cpu::div_out(out_t, in0_t, in1_t, rounding_mode);
+    }
   };
 });
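For reference, the convention passed to the kernel is 0 for true division, 1 for "trunc", and 2 for "floor"; only the first character of the rounding mode is peeked. A standalone sketch of that mapping (the helper name is illustrative, not part of the commit):

#include <c10/util/Optional.h>
#include <c10/util/string_view.h>

// Illustrative helper mirroring the peek logic above: 0 = plain a / b,
// 1 = truncate toward zero after dividing, 2 = floor after dividing.
int encodeRoundingMode(const c10::optional<c10::string_view>& rounding_mode) {
  if (!rounding_mode || rounding_mode->empty()) {
    return 0;
  }
  switch (rounding_mode->at(0)) {
    case 't': // "trunc"
      return 1;
    case 'f': // "floor"
      return 2;
    default:
      return 0;
  }
}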

View File

@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/runtime/static/impl.h>
+#include <torch/csrc/jit/tensorexpr/expr.h>
 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
@@ -113,6 +114,46 @@ void updateNNCCache(NodeKind kind, std::shared_ptr<TEWrapper> code) {
 } // namespace
+std::shared_ptr<TEWrapper> createDiv() {
+  auto wrap = lookupNNCCache(aten::div);
+  if (wrap) {
+    return wrap;
+  }
+  wrap = std::make_shared<TEWrapper>();
+  auto dim = VarHandle("dim", kInt);
+  auto mode = VarHandle("mode", kInt);
+  BufHandle A("A", {dim}, kFloat);
+  BufHandle B("B", {dim}, kFloat);
+  using axis = const VarHandle&;
+  Tensor C = Compute("C", {dim}, [&](axis x) {
+    auto true_div_result = A.load(x) / B.load(x);
+    auto mode_default = IntImm::make(0);
+    auto mode_trunc = IntImm::make(1);
+    auto mode_floor = IntImm::make(2);
+    // this is a glorified ternary choice operator train
+    return CompareSelect::make(
+        mode,
+        mode_default,
+        true_div_result,
+        CompareSelect::make(
+            mode,
+            mode_trunc,
+            trunc(true_div_result),
+            floor(true_div_result),
+            kEQ),
+        kEQ);
+  });
+  wrap = wrapTECompute(wrap, C, {A, B, mode, dim});
+  updateNNCCache(aten::div, wrap);
+  return wrap;
+}
 std::shared_ptr<TEWrapper> createLogit() {
   auto wrap = lookupNNCCache(aten::logit);
   if (wrap) {

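The nested CompareSelect pair above is the tensor-expression equivalent of a two-level ternary on the mode scalar, and wrapTECompute binds the kernel arguments as {A, B, mode, dim}, which is why te->call in the operator passes the output pointer first, then the two inputs, the rounding-mode code, and the element count. A standalone per-element sketch in plain C++ (the function name is illustrative, not part of the commit):

#include <cmath>

// Illustrative scalar equivalent of the nested CompareSelect in createDiv:
// mode == 0 -> true division, mode == 1 -> trunc, anything else -> floor.
float divWithMode(float a, float b, int mode) {
  const float true_div_result = a / b;
  if (mode == 0) {
    return true_div_result;
  }
  return mode == 1 ? std::trunc(true_div_result) : std::floor(true_div_result);
}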
View File

@@ -33,6 +33,7 @@ class TEWrapper {
 #endif
 };
+std::shared_ptr<TEWrapper> createDiv();
 std::shared_ptr<TEWrapper> createLogit();
 std::shared_ptr<TEWrapper> createRelu();
 std::shared_ptr<TEWrapper> createTanh();

View File

@@ -102,9 +102,9 @@ def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> N
         return
     if op_name == "take_along_dim":
         if index == 0:
-            arg_map["indices"] = "at::argsort(self0, 1)"
+            arg_map["indices"] = "at::argsort(self0, 1, true)"
         else:
-            arg_map["indices"] = "at::argsort(self1, 1)"
+            arg_map["indices"] = "at::argsort(self1, 1, true)"
         return
     if op_name == "masked_select":
         if index == 0: