#include "gtest/gtest.h"

// NOTE(review): the header names and every template-argument list in this
// file were lost (angle-bracket contents stripped). They have been restored
// from the symbols the tests reference and from what each assertion expects;
// please confirm the include list against the build before merging.
#include <torch/csrc/utils/variadic.h>
#include <torch/detail/static.h>
#include <torch/torch.h>

#include <cstddef>
#include <string>
#include <vector>

// Overload selected when T is NOT a module type: reports `false`.
// Together with the overload below, this lets the Enable_If_Module test
// observe which branch of the SFINAE constraint was taken.
template <
    typename T,
    typename = torch::enable_if_t<!torch::detail::is_module<T>::value>>
bool f(T&& /*m*/) {
  return false;
}

// Overload selected through `enable_if_module_t` (i.e. when T is a module
// type): reports `true`. The second template argument fixes the return type
// to `bool` so the result can be fed to EXPECT_TRUE / EXPECT_FALSE.
template <typename T>
torch::detail::enable_if_module_t<T, bool> f(T&& /*m*/) {
  return true;
}

// all_of: the empty pack and all-true packs yield true; any false yields false.
TEST(TestStatic, All_Of) {
  EXPECT_TRUE(torch::all_of<>::value);
  EXPECT_TRUE(torch::all_of<true>::value);
  // Extra parentheses keep the commas in the template argument list from
  // being parsed as macro-argument separators.
  EXPECT_TRUE((torch::all_of<true, true, true>::value));
  EXPECT_FALSE(torch::all_of<false>::value);
  EXPECT_FALSE((torch::all_of<false, false, false>::value));
  EXPECT_FALSE((torch::all_of<true, true, false>::value));
}

// any_of: the empty pack and all-false packs yield false; any true yields true.
TEST(TestStatic, Any_Of) {
  EXPECT_FALSE(torch::any_of<>::value);
  EXPECT_TRUE(bool((torch::any_of<true>::value)));
  EXPECT_TRUE(bool((torch::any_of<true, true, true>::value)));
  EXPECT_FALSE(bool((torch::any_of<false>::value)));
}

// Checks overload selection of `f` for a module vs. a non-module argument,
// and that check_not_lvalue_references rejects packs containing an lvalue
// reference type.
TEST(TestStatic, Enable_If_Module) {
  EXPECT_TRUE(f(torch::nn::LinearImpl(1, 2)));
  EXPECT_FALSE(f(5));
  EXPECT_TRUE(torch::detail::check_not_lvalue_references<int>());
  EXPECT_TRUE((torch::detail::check_not_lvalue_references<float, int, char>()));
  EXPECT_FALSE(
      (torch::detail::check_not_lvalue_references<float, int&, char>()));
  EXPECT_TRUE(torch::detail::check_not_lvalue_references<std::string>());
  EXPECT_FALSE(torch::detail::check_not_lvalue_references<std::string&>());
}

// torch::apply must invoke the callable once per trailing argument, in order.
TEST(TestStatic, Apply) {
  std::vector<int> v;
  torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
  // Unsigned literal: v.size() is size_t, so comparing against a plain `5`
  // would be a signed/unsigned comparison inside gtest's helper.
  EXPECT_EQ(v.size(), 5u);
  for (size_t i = 0; i < v.size(); ++i) {
    // Elements are int; convert the index explicitly rather than comparing
    // an int against the size_t expression `i + 1`.
    EXPECT_EQ(v.at(i), static_cast<int>(i) + 1);
  }
}