diff --git a/test/cpp/jit/test_autodiff.cpp b/test/cpp/jit/test_autodiff.cpp
index 3993c63b170..7d431776a97 100644
--- a/test/cpp/jit/test_autodiff.cpp
+++ b/test/cpp/jit/test_autodiff.cpp
@@ -1,5 +1,4 @@
-#include <gtest/gtest.h>
-
+#include "test/cpp/jit/test_base.h"
 #include "test/cpp/jit/test_utils.h"
 #include "torch/csrc/jit/frontend/tracer.h"
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
@@ -84,7 +83,7 @@ variable_list grad(
       fmap(inputs, get_edge));
 }
 
-TEST(AutodiffTest, ADFormulas) {
+void testADFormulas() {
   const auto cast = [](const Variable& v) {
     return static_cast<at::Tensor>(v);
   };
@@ -175,7 +174,7 @@ TEST(AutodiffTest, ADFormulas) {
   }
 }
 
-TEST(AutodiffTest, Differentiate) {
+void testDifferentiate() {
   // Note: can't use IRParser for this test due to issue #23989
   auto graph = std::make_shared<Graph>();
   std::vector<int64_t> sizes{2, 3, 4};
@@ -230,7 +229,7 @@ TEST(AutodiffTest, Differentiate) {
       ->run(*grad_spec.df);
 }
 
-TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
+void testDifferentiateWithRequiresGrad() {
   const auto graph_string = R"IR(
   graph(%0 : Tensor,
         %1 : Tensor):
diff --git a/test/cpp/jit/test_class_import.cpp b/test/cpp/jit/test_class_import.cpp
index ffa845b3e2a..82bc0cf3bcc 100644
--- a/test/cpp/jit/test_class_import.cpp
+++ b/test/cpp/jit/test_class_import.cpp
@@ -1,7 +1,7 @@
-#include <gtest/gtest.h>
+#include "test/cpp/jit/test_base.h"
+#include "test/cpp/jit/test_utils.h"
 
 #include
-#include
 #include
 #include
 #include
@@ -45,7 +45,7 @@ static void import_libs(
   si.loadType(QualifiedName(class_name));
 }
 
-TEST(ClassImportTest, Basic) {
+void testClassImport() {
   auto cu1 = std::make_shared<CompilationUnit>();
   auto cu2 = std::make_shared<CompilationUnit>();
   std::vector<at::IValue> constantTable;
@@ -80,7 +80,7 @@ TEST(ClassImportTest, Basic) {
   ASSERT_FALSE(c);
 }
 
-TEST(ClassImportTest, ScriptObject) {
+void testScriptObject() {
   Module m1("m1");
   Module m2("m2");
   std::vector<at::IValue> constantTable;
@@ -114,7 +114,7 @@ def __init__(self, x):
     return x
 )JIT";
 
-TEST(ClassImportTest, ClassDerive) {
+void testClassDerive() {
   auto cu = std::make_shared<CompilationUnit>();
   auto cls = ClassType::create("foo.bar", cu);
   const auto self = SimpleSelf(cls);
@@ -142,7 +142,7 @@ class FooBar1234(Module):
     return (self.f).top()
 )JIT";
 
-TEST(ClassImportTest, CustomClass) {
+void testSaveLoadTorchbind() {
   auto cu1 = std::make_shared<CompilationUnit>();
   std::vector<at::IValue> constantTable;
   // Import different versions of FooTest into two namespaces.
diff --git a/test/cpp/jit/test_class_parser.cpp b/test/cpp/jit/test_class_parser.cpp
index a5b19f63fd3..45e37103bb5 100644
--- a/test/cpp/jit/test_class_parser.cpp
+++ b/test/cpp/jit/test_class_parser.cpp
@@ -1,5 +1,3 @@
-#include <gtest/gtest.h>
-
 #include
 #include
 #include
@@ -17,7 +15,7 @@ const auto testSource = R"JIT(
     an_attribute : Tensor
 )JIT";
 
-TEST(ClassParserTest, Basic) {
+void testClassParser() {
   Parser p(std::make_shared<Source>(testSource));
   std::vector definitions;
   std::vector resolvers;
diff --git a/test/cpp/jit/test_cleanup_passes.cpp b/test/cpp/jit/test_cleanup_passes.cpp
index 38ceef932eb..2f2ca4e0a19 100644
--- a/test/cpp/jit/test_cleanup_passes.cpp
+++ b/test/cpp/jit/test_cleanup_passes.cpp
@@ -1,19 +1,19 @@
-#include <gtest/gtest.h>
-
 #include
 #include
 #include
 #include
+#include "test/cpp/jit/test_base.h"
 
 namespace torch {
 namespace jit {
 
-TEST(CleanupPassTest, Basic) {
+void testCleanUpPasses() {
   // Tests stability of clean up passes when dealing with constant pooling
   // and constant propagation.
-  auto graph = std::make_shared<Graph>();
-  parseIR(
-      R"IR(
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
 graph(%cond.1 : Tensor,
       %suffix.1 : str):
   %3 : bool = aten::Bool(%cond.1) # o.py:6:7
@@ -31,19 +31,20 @@ graph(%cond.1 : Tensor,
     -> (%12)
   return (%25)
 )IR",
-      &*graph);
-  runCleanupPasses(graph);
-  testing::FileCheck()
-      .check_count(
-          "prim::Constant[value=\"same string with a twist\"]",
-          1,
-          /*exactly=*/true)
-      ->run(*graph);
+        &*graph);
+    runCleanupPasses(graph);
+    testing::FileCheck()
+        .check_count(
+            "prim::Constant[value=\"same string with a twist\"]",
+            1,
+            /*exactly=*/true)
+        ->run(*graph);
 
-  auto graph_after_pass_once = graph->toString();
-  runCleanupPasses(graph);
-  auto graph_after_pass_twice = graph->toString();
-  ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
+    auto graph_after_pass_once = graph->toString();
+    runCleanupPasses(graph);
+    auto graph_after_pass_twice = graph->toString();
+    ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
+  }
 }
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/jit/test_code_template.cpp b/test/cpp/jit/test_code_template.cpp
index 35897474f1f..e4d7d1ef856 100644
--- a/test/cpp/jit/test_code_template.cpp
+++ b/test/cpp/jit/test_code_template.cpp
@@ -1,6 +1,6 @@
-#include <gtest/gtest.h>
+#include "test/cpp/jit/test_base.h"
+#include "test/cpp/jit/test_utils.h"
 
-#include
 #include "torch/csrc/jit/frontend/code_template.h"
 namespace torch {
 namespace jit {
@@ -33,29 +33,31 @@ static const auto ct_expect = R"(
   int notest(int a)
 )";
 
-TEST(TestCodeTemplate, Copying) {
-  TemplateEnv e;
-  e.s("hi", "foo");
-  e.v("what", {"is", "this"});
-  TemplateEnv c(e);
-  c.s("hi", "foo2");
-  ASSERT_EQ(e.s("hi"), "foo");
-  ASSERT_EQ(c.s("hi"), "foo2");
-  ASSERT_EQ(e.v("what")[0], "is");
-}
+void testCodeTemplate() {
+  {
+    TemplateEnv e;
+    e.s("hi", "foo");
+    e.v("what", {"is", "this"});
+    TemplateEnv c(e);
+    c.s("hi", "foo2");
+    ASSERT_EQ(e.s("hi"), "foo");
+    ASSERT_EQ(c.s("hi"), "foo2");
+    ASSERT_EQ(e.v("what")[0], "is");
+  }
 
-TEST(TestCodeTemplate, Formatting) {
-  TemplateEnv e;
-  e.v("args", {"hi", "8"});
-  e.v("bar", {"what\non many\nlines...", "7"});
-  e.s("a", "3");
-  e.s("b", "4");
-  e.v("stuff", {"things...", "others"});
-  e.v("empty", {});
-  auto s = ct.format(e);
-  // std::cout << "'" << s << "'\n";
-  // std::cout << "'" << ct_expect << "'\n";
-  ASSERT_EQ(s, ct_expect);
+  {
+    TemplateEnv e;
+    e.v("args", {"hi", "8"});
+    e.v("bar", {"what\non many\nlines...", "7"});
+    e.s("a", "3");
+    e.s("b", "4");
+    e.v("stuff", {"things...", "others"});
+    e.v("empty", {});
+    auto s = ct.format(e);
+    // std::cout << "'" << s << "'\n";
+    // std::cout << "'" << ct_expect << "'\n";
+    ASSERT_EQ(s, ct_expect);
+  }
 }
 
 } // namespace jit
diff --git a/test/cpp/jit/test_constant_pooling.cpp b/test/cpp/jit/test_constant_pooling.cpp
index c8cb58e1886..b949c9a45b2 100644
--- a/test/cpp/jit/test_constant_pooling.cpp
+++ b/test/cpp/jit/test_constant_pooling.cpp
@@ -1,10 +1,9 @@
-#include <gtest/gtest.h>
-
 #include
 #include
 #include
 #include
 #include
+#include "test/cpp/jit/test_base.h"
 
 #include
 #include
@@ -12,26 +11,26 @@
 namespace torch {
 namespace jit {
 
-TEST(ConstantPoolingTest, Int) {
-  auto graph = std::make_shared<Graph>();
-  parseIR(
-      R"IR(
+void testConstantPooling() {
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
 graph():
   %8 : int = prim::Constant[value=1]()
   %10 : int = prim::Constant[value=1]()
   return (%8, %10)
 )IR",
-      &*graph);
-  ConstantPooling(graph);
-  testing::FileCheck()
-      .check_count("prim::Constant", 1, /*exactly*/ true)
-      ->run(*graph);
-}
-
-TEST(ConstantPoolingTest, PoolingAcrossBlocks) {
-  auto graph = std::make_shared<Graph>();
-  parseIR(
-      R"IR(
+        &*graph);
+    ConstantPooling(graph);
+    testing::FileCheck()
+        .check_count("prim::Constant", 1, /*exactly*/ true)
+        ->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
 graph(%cond : Tensor):
   %a : str = prim::Constant[value="bcd"]()
   %3 : bool = aten::Bool(%cond)
@@ -45,18 +44,17 @@ graph(%cond : Tensor):
   %7 : (str, str) = prim::TupleConstruct(%a, %b)
   return (%7)
 )IR",
-      &*graph);
-  ConstantPooling(graph);
-  testing::FileCheck()
-      .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
-      ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
-      ->run(*graph);
-}
-
-TEST(ConstantPoolingTest, PoolingDifferentDevices) {
-  auto graph = std::make_shared<Graph>();
-  parseIR(
-      R"IR(
+        &*graph);
+    ConstantPooling(graph);
+    testing::FileCheck()
+        .check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
+        ->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
+        ->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
 graph():
   %2 : int = prim::Constant[value=2]()
   %1 : int = prim::Constant[value=1]()
@@ -72,21 +70,22 @@ graph():
   prim::Print(%x, %y, %z)
   return (%1)
 )IR",
-      &*graph);
-  // three tensors created - two different devices among the three
-  // don't have good support for parsing tensor constants
-  ConstantPropagation(graph);
-  ConstantPooling(graph);
-  testing::FileCheck()
-      .check_count(
-          "Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
-          1,
-          /*exactly*/ true)
-      ->check_count(
-          "Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
-          1,
-          /*exactly*/ true)
-      ->run(*graph);
+        &*graph);
+    // three tensors created - two different devices among the three
+    // don't have good support for parsing tensor constants
+    ConstantPropagation(graph);
+    ConstantPooling(graph);
+    testing::FileCheck()
+        .check_count(
+            "Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
+            1,
+            /*exactly*/ true)
+        ->check_count(
+            "Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
+            1,
+            /*exactly*/ true)
+        ->run(*graph);
+  }
 }
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/jit/test_create_autodiff_subgraphs.cpp b/test/cpp/jit/test_create_autodiff_subgraphs.cpp
index e97043f84d2..8da6d9d6a1b 100644
--- a/test/cpp/jit/test_create_autodiff_subgraphs.cpp
+++ b/test/cpp/jit/test_create_autodiff_subgraphs.cpp
@@ -1,5 +1,4 @@
-#include <gtest/gtest.h>
-
+#include "test/cpp/jit/test_base.h"
 #include "test/cpp/jit/test_utils.h"
 
 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
@@ -7,7 +6,7 @@
 namespace torch {
 namespace jit {
 
-TEST(CreateAutodiffSubgraphsTest, Basic) {
+void testCreateAutodiffSubgraphs() {
   auto graph = build_lstm();
   CreateAutodiffSubgraphs(graph, /*threshold=*/2);
   // all of the ops are within the DifferentiableGraph
diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp
index 25c518d3142..543fbc20eb3 100644
--- a/test/cpp/jit/test_custom_class.cpp
+++ b/test/cpp/jit/test_custom_class.cpp
@@ -1,5 +1,3 @@
-#include <gtest/gtest.h>
-
 #include
 #include
 
@@ -320,7 +318,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
 
 } // namespace
 
-TEST(CustomClassTest, TorchbindIValueAPI) {
+void testTorchbindIValueAPI() {
   script::Module m("m");
 
   // test make_custom_class API
diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp
index d3f61268e8f..529b36385bd 100644
--- a/test/cpp/jit/test_custom_operators.cpp
+++ b/test/cpp/jit/test_custom_operators.cpp
@@ -1,5 +1,4 @@
-#include <gtest/gtest.h>
-
+#include "test/cpp/jit/test_base.h"
 #include "test/cpp/jit/test_utils.h"
 
 #include "torch/csrc/jit/ir/alias_analysis.h"
@@ -12,135 +11,134 @@
 namespace torch {
 namespace jit {
 
-TEST(CustomOperatorTest, InferredSchema) {
-  torch::RegisterOperators reg(
-      "foo::bar", [](double a, at::Tensor b) { return a + b; });
-  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
-  ASSERT_EQ(ops.size(), 1);
+void testCustomOperators() {
+  {
+    torch::RegisterOperators reg(
+        "foo::bar", [](double a, at::Tensor b) { return a + b; });
+    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
+    ASSERT_EQ(ops.size(), 1);
 
-  auto& op = ops.front();
-  ASSERT_EQ(op->schema().name(), "foo::bar");
+    auto& op = ops.front();
+    ASSERT_EQ(op->schema().name(), "foo::bar");
 
-  ASSERT_EQ(op->schema().arguments().size(), 2);
-  ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
-  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
-  ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
-  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
+    ASSERT_EQ(op->schema().arguments().size(), 2);
+    ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
+    ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
+    ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
+    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
 
-  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
+    ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
 
-  Stack stack;
-  push(stack, 2.0f, at::ones(5));
-  op->getOperation()(&stack);
-  at::Tensor output;
-  pop(stack, output);
+    Stack stack;
+    push(stack, 2.0f, at::ones(5));
+    op->getOperation()(&stack);
+    at::Tensor output;
+    pop(stack, output);
 
-  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
+    ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
+  }
+  {
+    torch::RegisterOperators reg(
+        "foo::bar_with_schema(float a, Tensor b) -> Tensor",
+        [](double a, at::Tensor b) { return a + b; });
+
+    auto& ops =
+        getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
+    ASSERT_EQ(ops.size(), 1);
+
+    auto& op = ops.front();
+    ASSERT_EQ(op->schema().name(), "foo::bar_with_schema");
+
+    ASSERT_EQ(op->schema().arguments().size(), 2);
+    ASSERT_EQ(op->schema().arguments()[0].name(), "a");
+    ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
+    ASSERT_EQ(op->schema().arguments()[1].name(), "b");
+    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
+
+    ASSERT_EQ(op->schema().returns().size(), 1);
+    ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
+
+    Stack stack;
+    push(stack, 2.0f, at::ones(5));
+    op->getOperation()(&stack);
+    at::Tensor output;
+    pop(stack, output);
+
+    ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
+  }
+  {
+    // Check that lists work well.
+    torch::RegisterOperators reg(
+        "foo::lists(int[] ints, float[] floats, Tensor[] tensors) -> float[]",
+        [](torch::List<int64_t> ints,
+           torch::List<double> floats,
+           torch::List<at::Tensor> tensors) { return floats; });
+
+    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
+    ASSERT_EQ(ops.size(), 1);
+
+    auto& op = ops.front();
+    ASSERT_EQ(op->schema().name(), "foo::lists");
+
+    ASSERT_EQ(op->schema().arguments().size(), 3);
+    ASSERT_EQ(op->schema().arguments()[0].name(), "ints");
+    ASSERT_TRUE(
+        op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofInts()));
+    ASSERT_EQ(op->schema().arguments()[1].name(), "floats");
+    ASSERT_TRUE(
+        op->schema().arguments()[1].type()->isSubtypeOf(ListType::ofFloats()));
+    ASSERT_EQ(op->schema().arguments()[2].name(), "tensors");
+    ASSERT_TRUE(
+        op->schema().arguments()[2].type()->isSubtypeOf(ListType::ofTensors()));
+
+    ASSERT_EQ(op->schema().returns().size(), 1);
+    ASSERT_TRUE(
+        op->schema().returns()[0].type()->isSubtypeOf(ListType::ofFloats()));
+
+    Stack stack;
+    push(stack, c10::List<int64_t>({1, 2}));
+    push(stack, c10::List<double>({1.0, 2.0}));
+    push(stack, c10::List<at::Tensor>({at::ones(5)}));
+    op->getOperation()(&stack);
+    c10::List<double> output;
+    pop(stack, output);
+
+    ASSERT_EQ(output.size(), 2);
+    ASSERT_EQ(output.get(0), 1.0);
+    ASSERT_EQ(output.get(1), 2.0);
+  }
+  {
+    torch::RegisterOperators reg(
+        "foo::lists2(Tensor[] tensors) -> Tensor[]",
+        [](torch::List<at::Tensor> tensors) { return tensors; });
+
+    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
+    ASSERT_EQ(ops.size(), 1);
+
+    auto& op = ops.front();
+    ASSERT_EQ(op->schema().name(), "foo::lists2");
+
+    ASSERT_EQ(op->schema().arguments().size(), 1);
+    ASSERT_EQ(op->schema().arguments()[0].name(), "tensors");
+    ASSERT_TRUE(
+        op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofTensors()));
+
+    ASSERT_EQ(op->schema().returns().size(), 1);
+    ASSERT_TRUE(
+        op->schema().returns()[0].type()->isSubtypeOf(ListType::ofTensors()));
+
+    Stack stack;
+    push(stack, c10::List<at::Tensor>({at::ones(5)}));
+    op->getOperation()(&stack);
+    c10::List<at::Tensor> output;
+    pop(stack, output);
+
+    ASSERT_EQ(output.size(), 1);
+    ASSERT_TRUE(output.get(0).allclose(at::ones(5)));
+  }
 }
 
-TEST(CustomOperatorTest, ExplicitSchema) {
-  torch::RegisterOperators reg(
-      "foo::bar_with_schema(float a, Tensor b) -> Tensor",
-      [](double a, at::Tensor b) { return a + b; });
-
-  auto& ops =
-      getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
-  ASSERT_EQ(ops.size(), 1);
-
-  auto& op = ops.front();
-  ASSERT_EQ(op->schema().name(), "foo::bar_with_schema");
-
-  ASSERT_EQ(op->schema().arguments().size(), 2);
-  ASSERT_EQ(op->schema().arguments()[0].name(), "a");
-  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
-  ASSERT_EQ(op->schema().arguments()[1].name(), "b");
-  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
-
-  ASSERT_EQ(op->schema().returns().size(), 1);
-  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
-
-  Stack stack;
-  push(stack, 2.0f, at::ones(5));
-  op->getOperation()(&stack);
-  at::Tensor output;
-  pop(stack, output);
-
-  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
-}
-
-TEST(CustomOperatorTest, ListParameters) {
-  // Check that lists work well.
-  torch::RegisterOperators reg(
-      "foo::lists(int[] ints, float[] floats, Tensor[] tensors) -> float[]",
-      [](torch::List<int64_t> ints,
-         torch::List<double> floats,
-         torch::List<at::Tensor> tensors) { return floats; });
-
-  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
-  ASSERT_EQ(ops.size(), 1);
-
-  auto& op = ops.front();
-  ASSERT_EQ(op->schema().name(), "foo::lists");
-
-  ASSERT_EQ(op->schema().arguments().size(), 3);
-  ASSERT_EQ(op->schema().arguments()[0].name(), "ints");
-  ASSERT_TRUE(
-      op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofInts()));
-  ASSERT_EQ(op->schema().arguments()[1].name(), "floats");
-  ASSERT_TRUE(
-      op->schema().arguments()[1].type()->isSubtypeOf(ListType::ofFloats()));
-  ASSERT_EQ(op->schema().arguments()[2].name(), "tensors");
-  ASSERT_TRUE(
-      op->schema().arguments()[2].type()->isSubtypeOf(ListType::ofTensors()));
-
-  ASSERT_EQ(op->schema().returns().size(), 1);
-  ASSERT_TRUE(
-      op->schema().returns()[0].type()->isSubtypeOf(ListType::ofFloats()));
-
-  Stack stack;
-  push(stack, c10::List<int64_t>({1, 2}));
-  push(stack, c10::List<double>({1.0, 2.0}));
-  push(stack, c10::List<at::Tensor>({at::ones(5)}));
-  op->getOperation()(&stack);
-  c10::List<double> output;
-  pop(stack, output);
-
-  ASSERT_EQ(output.size(), 2);
-  ASSERT_EQ(output.get(0), 1.0);
-  ASSERT_EQ(output.get(1), 2.0);
-}
-
-TEST(CustomOperatorTest, ListParameters2) {
-  torch::RegisterOperators reg(
-      "foo::lists2(Tensor[] tensors) -> Tensor[]",
-      [](torch::List<at::Tensor> tensors) { return tensors; });
-
-  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
-  ASSERT_EQ(ops.size(), 1);
-
-  auto& op = ops.front();
-  ASSERT_EQ(op->schema().name(), "foo::lists2");
-
-  ASSERT_EQ(op->schema().arguments().size(), 1);
-  ASSERT_EQ(op->schema().arguments()[0].name(), "tensors");
-  ASSERT_TRUE(
-      op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofTensors()));
-
-  ASSERT_EQ(op->schema().returns().size(), 1);
-  ASSERT_TRUE(
-      op->schema().returns()[0].type()->isSubtypeOf(ListType::ofTensors()));
-
-  Stack stack;
-  push(stack, c10::List<at::Tensor>({at::ones(5)}));
-  op->getOperation()(&stack);
-  c10::List<at::Tensor> output;
-  pop(stack, output);
-
-  ASSERT_EQ(output.size(), 1);
-  ASSERT_TRUE(output.get(0).allclose(at::ones(5)));
-}
-
-TEST(CustomOperatorTest, Aliasing) {
+void testCustomOperatorAliasing() {
   torch::RegisterOperators reg(
       "foo::aliasing", [](at::Tensor a, at::Tensor b) -> at::Tensor {
         a.add_(b);
@@ -184,65 +182,77 @@ graph(%x: Tensor, %y: Tensor):
   }
 }
 
-static constexpr char op_list[] = "foofoo::bar.template;foo::another";
+void testIValueKWargs() {
+  const auto text = R"(
+    def foo(a : int, b : int, c : int = 4):
+      return a + 2*b + 3*c
+  )";
+  auto cu = compile(text);
+  auto result = cu->get_function("foo")({1}, {{"b", 3}});
+  ASSERT_EQ(result.toInt(), 19);
+}
+
+void testTemplatedOperatorCreator() {
+  constexpr char op_list[] = "foofoo::bar.template;foo::another";
 #define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n) \
   torch::detail::SelectiveStr(n)
 
-TEST(TestCustomOperator, OperatorGeneratorUndeclared) {
-  // Try to register an op name that does not exist in op_list.
-  // Expected: the op name is not registered.
-  torch::jit::RegisterOperators reg({OperatorGenerator(
-      TORCH_SELECTIVE_NAME_IN_SCHEMA(
-          op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"),
-      [](Stack* stack) {
-        double a;
-        at::Tensor b;
-        pop(stack, a, b);
-        push(stack, a + b);
-      },
-      aliasAnalysisFromSchema())});
+  {
+    // Try to register an op name that does not exist in op_list.
+    // Expected: the op name is not registered.
+    torch::jit::RegisterOperators reg({OperatorGenerator(
+        TORCH_SELECTIVE_NAME_IN_SCHEMA(
+            op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"),
+        [](Stack* stack) {
+          double a;
+          at::Tensor b;
+          pop(stack, a, b);
+          push(stack, a + b);
+        },
+        aliasAnalysisFromSchema())});
 
-  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
-  ASSERT_EQ(ops.size(), 0);
-}
+    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist"));
+    ASSERT_EQ(ops.size(), 0);
+  }
 
-TEST(TestCustomOperator, OperatorGeneratorBasic) {
-  // The operator should be successfully registered since its name is in the
-  // whitelist.
-  torch::jit::RegisterOperators reg({OperatorGenerator(
-      TORCH_SELECTIVE_NAME_IN_SCHEMA(
-          op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"),
-      [](Stack* stack) {
-        double a;
-        at::Tensor b;
-        pop(stack, a, b);
-        push(stack, a + b);
-      },
-      aliasAnalysisFromSchema())});
+  {
+    // The operator should be successfully registered since its name is in the
+    // whitelist.
+    torch::jit::RegisterOperators reg({OperatorGenerator(
+        TORCH_SELECTIVE_NAME_IN_SCHEMA(
+            op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"),
+        [](Stack* stack) {
+          double a;
+          at::Tensor b;
+          pop(stack, a, b);
+          push(stack, a + b);
+        },
+        aliasAnalysisFromSchema())});
 
-  auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
-  ASSERT_EQ(ops.size(), 1);
+    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar"));
+    ASSERT_EQ(ops.size(), 1);
 
-  auto& op = ops.front();
-  ASSERT_EQ(op->schema().name(), "foofoo::bar");
+    auto& op = ops.front();
+    ASSERT_EQ(op->schema().name(), "foofoo::bar");
 
-  ASSERT_EQ(op->schema().arguments().size(), 2);
-  ASSERT_EQ(op->schema().arguments()[0].name(), "a");
-  ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
-  ASSERT_EQ(op->schema().arguments()[1].name(), "b");
-  ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
+    ASSERT_EQ(op->schema().arguments().size(), 2);
+    ASSERT_EQ(op->schema().arguments()[0].name(), "a");
+    ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
+    ASSERT_EQ(op->schema().arguments()[1].name(), "b");
+    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
 
-  ASSERT_EQ(op->schema().returns().size(), 1);
-  ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
+    ASSERT_EQ(op->schema().returns().size(), 1);
+    ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
 
-  Stack stack;
-  push(stack, 2.0f, at::ones(5));
-  op->getOperation()(&stack);
-  at::Tensor output;
-  pop(stack, output);
+    Stack stack;
+    push(stack, 2.0f, at::ones(5));
+    op->getOperation()(&stack);
+    at::Tensor output;
+    pop(stack, output);
 
-  ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
+    ASSERT_TRUE(output.allclose(at::full(5, 3.0f)));
+  }
 }
 
 } // namespace jit
diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp
index 92baba1168d..953d1bf42fc 100644
--- a/test/cpp/jit/test_misc.cpp
+++ b/test/cpp/jit/test_misc.cpp
@@ -2225,15 +2225,5 @@ void testProfilerDisableInCallback() {
   t.join();
 }
 
-void testIValueKWargs() {
-  const auto text = R"(
-    def foo(a : int, b : int, c : int = 4):
-      return a + 2*b + 3*c
-  )";
-  auto cu = compile(text);
-  auto result = cu->get_function("foo")({1}, {{"b", 3}});
-  ASSERT_EQ(result.toInt(), 19);
-}
-
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 8f43882c9e2..45d7f48b1f8 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -9,14 +9,22 @@ namespace torch {
 namespace jit {
 
 #define TH_FORALL_TESTS(_) \
+  _(ADFormulas) \
   _(Attributes) \
   _(Blocks) \
   _(CallStack) \
   _(CallStackCaching) \
+  _(CodeTemplate) \
   _(ControlFlow) \
+  _(CreateAutodiffSubgraphs) \
+  _(CustomOperators) \
+  _(CustomOperatorAliasing) \
+  _(TemplatedOperatorCreator) \
   _(IValueKWargs) \
   _(CustomFusion) \
   _(SchemaMatching) \
+  _(Differentiate) \
+  _(DifferentiateWithRequiresGrad) \
   _(FromQualString) \
   _(InternedStrings) \
   _(PassManagement) \
@@ -27,9 +35,12 @@ namespace jit {
   _(SubgraphUtils) \
   _(SubgraphUtilsVmap) \
   _(IRParser) \
+  _(ConstantPooling) \
+  _(CleanUpPasses) \
   _(THNNConv) \
   _(ATenNativeBatchNorm) \
   _(NoneSchemaMatch) \
+  _(ClassParser) \
   _(UnifyTypes) \
   _(Profiler) \
   _(FallbackGraphs) \
@@ -50,11 +61,15 @@ namespace jit {
   _(ModuleDeepcopyAliasing) \
   _(ModuleDefine) \
   _(QualifiedName) \
+  _(ClassImport) \
+  _(ScriptObject) \
   _(ExtraFilesHookPreference) \
   _(SaveExtraFilesHook) \
   _(TypeTags) \
   _(DCE) \
   _(CustomFusionNestedBlocks) \
+  _(ClassDerive) \
+  _(SaveLoadTorchbind) \
   _(ModuleInterfaceSerialization) \
   _(ModuleCloneWithModuleInterface) \
   _(ClassTypeAddRemoveAttr) \
@@ -85,6 +100,7 @@ namespace jit {
   _(LiteInterpreterHierarchyModuleInfo) \
   _(LiteInterpreterDuplicatedClassTypeModuleInfo) \
   _(LiteInterpreterEval) \
+  _(TorchbindIValueAPI) \
   _(LiteInterpreterDict) \
   _(LiteInterpreterFindAndRunMethod) \
   _(LiteInterpreterFindWrongMethodName) \