#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

namespace torch {
namespace jit {
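
// TorchScript sources shared by the tests below: an interface, the submodule
// methods that implement it, and a parent forward that dispatches through it.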
static constexpr std::string_view moduleInterfaceSrc = R"JIT(
class OneInterface(ModuleInterface):
  def one(self, x: Tensor, y: Tensor) -> Tensor:
    pass
)JIT";

static const std::vector<std::string> subModuleMethodsSrc = {R"JIT(
def one(self, x: Tensor, y: Tensor) -> Tensor:
  return self.attr * x + y + 1

def forward(self, x: Tensor) -> Tensor:
  return self.attr + x
)JIT"};

static const std::string parentForward = R"JIT(
def forward(self, x: Tensor) -> Tensor:
  return self.subMod1.one(x, x) + self.subMod2.one(x, x)
)JIT";
static void import_libs(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& class_name,
    const std::shared_ptr<Source>& src,
    const std::vector<at::IValue>& tensor_table) {
  SourceImporter si(
      cu,
      &tensor_table,
      [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
      /*version=*/2);
  si.loadType(QualifiedName(class_name));
}
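
// Checks that Method::run_async routes interpreter tasks through a
// caller-supplied task launcher (a std::function<void(std::function<void()>)>).
// A minimal sketch of an inline launcher, assuming nothing beyond <functional>:
//   auto inline_launcher = [](std::function<void()> f) { f(); };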
TEST(ModuleAPITest, MethodRunAsync) {
  // Module m("m");
  // m.define(R"(
  //   def forward(self):
  //     r1 = torch.jit.fork(torch.mm, torch.rand(100,100), torch.rand(100,100))
  //     r2 = torch.jit.fork(torch.mm, torch.rand(100,100), torch.rand(100,100))
  //     return r1.wait() + r2.wait()
  // )");
  std::string filePath(__FILE__);
  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  // borrow the model file from TEST(GraphExecutorTest, runAsync_executor)
  testModelFile.append("test_interpreter_async.pt");
  auto m = load(testModelFile);

  auto counter = 0;
  std::mutex mtx;

  // count every task the interpreter hands us before forwarding it to the
  // default ATen thread pool
  auto launcher = [&](std::function<void()> f) {
    mtx.lock();
    ++counter;
    mtx.unlock();
    at::launch(std::move(f));
  };

  auto method = m.get_method("forward");

  std::vector<IValue> stack;
  auto kwargs = std::unordered_map<std::string, at::IValue>();
  auto future = method.run_async(stack, kwargs, launcher);

  future->wait();

  // expect the 2 forks and 2 wait callbacks to execute on the provided
  // taskLauncher, but the ivalue::Future may be marked completed and release
  // wait() before all callbacks finish, so only assert a lower bound
  ASSERT_GE(counter, 2);
}
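
// clone() copies both type and data: cloned submodules that shared a
// ClassType keep sharing a (new) type while retaining distinct values.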
TEST(ModuleAPITest, Clone) {
  auto cu = std::make_shared<CompilationUnit>();
  // create a child module type
  auto child = ClassType::create("child", cu, true);
  auto attr_name = "attr";
  child->addAttribute(attr_name, IntType::get());
  Module c1(cu, child);
  auto v1 = IValue(2);
  c1.register_attribute(attr_name, IntType::get(), v1, false);
  Module c2(cu, child);
  auto v2 = IValue(3);
  c2.register_attribute(attr_name, IntType::get(), v2, false);

  // attach two child module instances that share a ClassType to the parent
  auto parent = ClassType::create("parent", cu, true);
  Module p(cu, parent);
  p.register_attribute("c1", c1.type(), c1._ivalue(), false);
  p.register_attribute("c2", c2.type(), c2._ivalue(), false);

  // clone the parent
  Module p2 = p.clone();
  // check that the two cloned children still share the same ClassType
  ASSERT_EQ(p2.attr("c1").type(), p2.attr("c2").type());
  // but remain distinct instances with their own attribute values
  ASSERT_EQ(Module(p2.attr("c1").toObject()).attr(attr_name).toInt(), 2);
  ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
}
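
// Cloning a module whose submodules are held behind a ModuleInterface must
// still produce a freshly-typed parent module.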
TEST(ModuleAPITest, CloneWithModuleInterface) {
  auto cu = std::make_shared<CompilationUnit>();

  // define an initial module whose two submodules share the same interface
  Module parentMod("parentMod", cu);
  Module subMod1("subMod1", cu);
  Module subMod2("subMod2", cu);

  std::vector<at::IValue> constantTable;
  import_libs(
      cu,
      "__torch__.OneInterface",
      std::make_shared<Source>(moduleInterfaceSrc),
      constantTable);

  auto v1 = IValue(2);
  subMod1.register_attribute("attr", IntType::get(), v1, false);

  auto v2 = IValue(4);
  subMod2.register_attribute("attr", IntType::get(), v2, false);

  for (const std::string& method : subModuleMethodsSrc) {
    subMod1.define(method, nativeResolver());
    subMod2.define(method, nativeResolver());
  }

  parentMod.register_attribute(
      "subMod1",
      cu->get_interface("__torch__.OneInterface"),
      subMod1._ivalue());
  parentMod.register_attribute(
      "subMod2",
      cu->get_interface("__torch__.OneInterface"),
      subMod2._ivalue());

  parentMod.define(parentForward, nativeResolver());

  Module clonedMod = parentMod.clone();

  // clone copies both type and data, therefore the cloned module gets a
  // different type
  ASSERT_NE(clonedMod.type(), parentMod.type());
}
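
// Contrasts clone() (new type, copied data) with copy() (shared type,
// copied data).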
TEST(ModuleAPITest, Copy) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr_name = "attr";
  cls->addAttribute(attr_name, IntType::get());
  Module m(cu, cls);
  auto v = IValue(2);
  m.register_attribute(attr_name, IntType::get(), v, false);

  Module m2 = m.clone();
  Module m3 = m.copy();

  // make sure both copies see the original value
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m3.attr(attr_name).toInt(), 2);

  // clone copies both type and data, so m2 gets a different type
  ASSERT_NE(m.type(), m2.type());
  // copy only copies data; the type is shared
  ASSERT_EQ(m.type(), m3.type());

  // change the value of one copied instance
  m3.register_attribute(attr_name, IntType::get(), IValue(3), false);
  // and verify the other instance doesn't change
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
}
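
// deepcopy() duplicates attribute storage, including Tensors; copy() shares
// it. Both preserve the module type.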
TEST(ModuleAPITest, DeepCopy) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto str_attr = "str_attr";
  auto int_attr = "int_attr";
  auto tensor_attr = "tensor_attr";
  auto tensor_list_attr = "tensor_list_attr";
  cls->addAttribute(int_attr, IntType::get());
  cls->addAttribute(str_attr, StringType::get());
  cls->addAttribute(tensor_attr, TensorType::get());
  cls->addAttribute(tensor_list_attr, ListType::ofTensors());
  Module m(cu, cls);
  c10::List<at::Tensor> list({at::rand(5), at::rand(5)});
  m.setattr(int_attr, IValue(2));
  m.setattr(str_attr, IValue("str"));
  m.setattr(tensor_attr, at::randn(5));
  m.setattr(tensor_list_attr, list);

  Module m2 = m.deepcopy();
  Module m3 = m.copy();

  // make sure both copies see the original value
  ASSERT_EQ(m2.attr(int_attr).toInt(), 2);
  ASSERT_EQ(m3.attr(int_attr).toInt(), 2);

  // deepcopy must not share storage with the original; copy does
  ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));
  ASSERT_TRUE(IValue(m3._ivalue()).overlaps(IValue(m._ivalue())));

  // both deepcopy and copy preserve the type
  ASSERT_EQ(m.type(), m2.type());
  ASSERT_EQ(m.type(), m3.type());

  // change the int value of the copied instances
  m2.setattr(int_attr, IValue(3));
  m3.setattr(int_attr, IValue(4));

  // verify the value of the original instance doesn't change
  ASSERT_EQ(m.attr(int_attr).toInt(), 2);
  ASSERT_EQ(m2.attr(int_attr).toInt(), 3);
  ASSERT_EQ(m3.attr(int_attr).toInt(), 4);

  // grab the Tensor attribute through each handle
  at::Tensor t1 = m.attr(tensor_attr).toTensor();
  at::Tensor t2 =
      m2.attr(tensor_attr).toTensor(); // deepcopy copied the Tensor
  at::Tensor t3 =
      m3.attr(tensor_attr).toTensor(); // copy did not copy the Tensor
  // check copy works
  ASSERT_TRUE(t1.equal(t2));
  ASSERT_TRUE(t1.equal(t3));

  // zero out t1
  t1.zero_();
  // check that t2 is not affected because it is a deep copy
  ASSERT_TRUE(!t1.equal(t2));
  // check that t3 is the same as t1 since it is a shallow copy
  ASSERT_TRUE(t1.equal(t3));
}
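
// Strings are copied by value, so mutating the source string after deepcopy
// must not show up in the copied module.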
TEST(ModuleAPITest, DeepCopyString) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr1 = "attr1";
  cls->addAttribute(attr1, StringType::get());
  std::string str = "str";
  Module m(cu, cls);
  m.setattr(attr1, str);
  auto copied = m.deepcopy();
  auto original_str = str;
  ASSERT_EQ(copied.attr(attr1).toStringRef(), original_str);
  // check that the string mutation is not reflected in the copied module
  str += "str";
  ASSERT_EQ(copied.attr(attr1).toStringRef(), original_str);
}
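
// Enum attributes survive deepcopy with value, name, and type intact, and
// the copy is independent of the original.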
TEST(ModuleAPITest, DeepCopyEnum) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto enum_attr = "enum_attr";
  auto int_enum_type = EnumType::create(
      "enum_class",
      IntType::get(),
      {{"enum_name_1", 1}, {"enum_name_2", 2}},
      cu);
  cls->addAttribute(enum_attr, int_enum_type);
  Module m(cu, cls);
  m.setattr(
      enum_attr,
      IValue(c10::make_intrusive<ivalue::EnumHolder>(
          int_enum_type, "enum_name_1", 1)));
  Module m2 = m.deepcopy();

  // make sure deepcopy works
  c10::ivalue::EnumHolder* m2_holder = m2.attr(enum_attr).toEnumHolder().get();
  ASSERT_EQ(m2_holder->value().toInt(), 1);
  ASSERT_EQ(m2_holder->name(), "enum_name_1");
  ASSERT_EQ(m2_holder->type(), int_enum_type);

  // the deep copy must not share storage with the original
  ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));

  // deepcopy preserves the type
  ASSERT_EQ(m.type(), m2.type());

  // changing the original should not affect the deep copy
  m.setattr(
      enum_attr,
      IValue(c10::make_intrusive<ivalue::EnumHolder>(
          int_enum_type, "enum_name_2", 2)));
  ASSERT_NE(
      m.attr(enum_attr).toEnumHolder().get()->value().toInt(),
      m2.attr(enum_attr).toEnumHolder().get()->value().toInt());
}
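
// Tensors that alias each other in the original (shared list elements,
// views) must still alias each other after deepcopy.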
TEST(ModuleAPITest, DeepCopyPreservesAliasing) {
  // check that deepcopy preserves aliasing
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr1 = "attr1";
  auto attr2 = "attr2";
  auto attr3 = "attr3";
  auto attr4 = "attr4";
  cls->addAttribute(attr1, ListType::ofTensors());
  cls->addAttribute(attr2, ListType::ofTensors());
  cls->addAttribute(attr3, TensorType::get());
  cls->addAttribute(attr4, TensorType::get());
  Module m(cu, cls);
  auto t1 = at::rand(5);
  auto t2 = at::rand(5);
  auto t3 = at::rand(5);
  auto t4 = at::rand({5, 2});
  c10::List<at::Tensor> list1({t1, t2});
  c10::List<at::Tensor> list2({t1, t3});
  // the first elements of attr1 and attr2 alias each other
  m.setattr(attr1, list1);
  m.setattr(attr2, list2);
  m.setattr(attr3, t4);
  m.setattr(attr4, t4.view(-1));

  auto copied = m.deepcopy();
  // test tensor aliasing
  auto copied_attr1_t1 = copied.attr(attr1).toList().get(0);
  auto copied_attr2_t1 = copied.attr(attr2).toList().get(0);
  ASSERT_TRUE(copied_attr1_t1.isAliasOf(copied_attr2_t1));

  // test aliasing through a view (attr4 was registered as a view of attr3)
  auto copied_attr3 = copied.attr(attr3);
  auto copied_attr4 = copied.attr(attr4);
  ASSERT_TRUE(copied_attr3.isAliasOf(copied_attr4));
}
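
// Class constants are readable through the same hasattr()/attr() API as
// registered attributes.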
TEST(ModuleAPITest, Constants) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr_name = "attr";
  auto const_name = "const";
  cls->addAttribute(attr_name, IntType::get());
  cls->addConstant(const_name, IValue(3));
  Module m(cu, cls);
  auto v = IValue(2);
  m.register_attribute(attr_name, IntType::get(), v, false);
  ASSERT_TRUE(m.hasattr(attr_name));
  ASSERT_TRUE(m.hasattr(const_name));
  ASSERT_EQ(m.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m.attr(const_name).toInt(), 3);
}
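
// parameters() reports only Tensor-valued parameters; None-valued parameters
// remain visible via hasattr() but are skipped.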
TEST(ModuleAPITest, Parameters) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  Module m(cu, cls);
  // Tensor parameter
  m.register_parameter(
      "tensor_param", at::empty({3}, at::kFloat), /* is_buffer */ false);
  // None parameters
  m.register_attribute(
      "none_param", NoneType::get(), IValue(), /* is_param */ true);
  m.register_attribute(
      "none_param2", NoneType::get(), IValue(), /* is_param */ true);
  auto param_list = m.parameters();
  // only the Tensor parameter shows up; the None parameters are visible
  // through hasattr() but not reported by parameters()
  ASSERT_EQ(param_list.size(), 1);
  ASSERT_TRUE(m.hasattr("tensor_param"));
  ASSERT_TRUE(m.hasattr("none_param"));
  ASSERT_TRUE(m.hasattr("none_param2"));
}
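
// define() compiles TorchScript source against the module so that
// run_method() can invoke it with default arguments applied.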
TEST(ModuleAPITest, Define) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x, b : int = 4):
      return self.foo + x + b
  )");
  auto result = m.run_method("add_it", torch::ones({}));
  // foo (1) + x (1) + b (default 4) == 6
  AT_ASSERT(result.toTensor().item<float>() == 6);
}
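
// freeze() bakes attribute values into the graph, removing GetAttr nodes;
// optimize_for_inference alone does not freeze a module that has no
// `training` attribute.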
TEST(ModuleAPITest, Freezing) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.eval();
  auto forward_g = m.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);

  // Removal of GetAttr is done by freezing
  auto frozen_mod = torch::jit::freeze(m);
  forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);

  // If no training attribute is set, the module is NOT frozen by OFI
  auto frozen_mod2 = torch::jit::optimize_for_inference(m);
  forward_g = frozen_mod2.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);
}
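
// With a `training` attribute registered, optimize_for_inference freezes the
// module, so GetAttr nodes disappear from forward.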
TEST(ModuleAPITest, OfiFreezesTraining) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.register_attribute("training", BoolType::get(), true);
  m.eval();

  // before freezing, forward still contains a GetAttr
  auto forward_g = m.get_method("forward").graph();
  testing::FileCheck().check("GetAttr")->run(*forward_g);

  // demonstrate that freezing happens when OFI is called: removal of GetAttr
  // is done by freezing, but only when the training attribute is set
  auto frozen_mod = torch::jit::optimize_for_inference(m);
  forward_g = frozen_mod.get_method("forward").graph();
  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}
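
// optimize_for_inference also works on modules without a forward method when
// the methods to preserve are passed explicitly.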
TEST(ModuleAPITest, OfiFreezesNoForward) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def bar(self, x, b : int = 4):
      return self.foo + x + b
  )");
  m.eval();

  // OFI is called on a module that has no forward method
  auto frozen_mod =
      torch::jit::optimize_for_inference(m, std::vector<std::string>{"bar"});
  ASSERT_EQ(
      m.run_method("bar", torch::ones({})).toTensor().item<float>(),
      frozen_mod.run_method("bar", torch::ones({})).toTensor().item<float>());
}
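
// Module::to() moves registered parameters and buffers between devices in
// both directions.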
TEST(ModuleAPITest, To_CUDA) {
  Module m("test");
  {
    // test CUDA-to-CPU moves for params and buffers
    m.register_parameter("foo", torch::ones({}, at::kCUDA), false);
    m.register_buffer("bar", torch::ones({}, at::kCUDA));

    m.to(at::kCUDA);
    m.to(at::kCPU);
    AT_ASSERT(m.attr("foo").toTensor().device().is_cpu());
    AT_ASSERT(m.attr("bar").toTensor().device().is_cpu());
  }
  {
    // test CPU-to-CUDA moves for params and buffers
    m.register_parameter("foo", torch::ones({}), false);
    m.register_buffer("bar", torch::ones({}));

    m.to(at::kCUDA);
    AT_ASSERT(m.attr("foo").toTensor().device().is_cuda());
    AT_ASSERT(m.attr("bar").toTensor().device().is_cuda());
  }
}

} // namespace jit
} // namespace torch