pytorch/test/cpp/api/serialize.cpp

#include <gtest/gtest.h>
Protobuf serialization (#11619) Summary: This PR serves two purposes: 1. Design an abstraction over a serialization scheme for C++ modules, optimizers and tensors in general, 2. Add serialization to the ONNX/PyTorch proto format. This is currently a rough prototype I coded up today, to get quick feedback. For this I propose the following serialization interface within the C++ API: ```cpp namespace torch { namespace serialize { class Reader { public: virtual ~Reader() = default; virtual void read(const std::string& key, Tensor& tensor, bool is_buffer = false) = 0; virtual void finish() { } }; class Writer { public: virtual ~Reader() = default; virtual void writer(const std::string& key, const Tensor& tensor, bool is_buffer = false) = 0; virtual void finish() { } }; }} // namespace torch::serialize ``` There are then subclasses of these two for (1) Cereal and (2) Protobuf (called the "DefaultWriter" and "DefaultReader" to hide the implementation details). See `torch/serialize/cereal.h` and `torch/serialize/default.h`. This abstraction and subclassing for these two allows us to: 1. Provide a cereal-less serialization forward that we can ship and iterate on going forward, 2. Provide no-friction backwards compatibility with existing C++ API uses, mainly StarCraft. The user-facing API is (conceptually): ```cpp void torch::save(const Module& module, Writer& writer); void torch::save(const Optimizer& optimizer, Writer& writer); void torch::read(Module& module, Reader& reader); void torch::read(Optimizer& optimizer, Reader& reader); ``` with implementations for both optimizers and modules that write into the `Writer` and read from the `Reader` ebetica ezyang zdevito dzhulgakov Pull Request resolved: https://github.com/pytorch/pytorch/pull/11619 Differential Revision: D9984664 Pulled By: goldsborough fbshipit-source-id: e03afaa646221546e7f93bb8dfe3558e384a5847
2018-09-20 20:36:22 -07:00
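The Reader/Writer split proposed in the commit message above can be illustrated with a small, torch-free sketch. Everything here is an assumption made for illustration and not the actual `torch::serialize` code: `FakeTensor` stands in for `Tensor`, `MapWriter`/`MapReader` are a hypothetical in-memory backend, and `ToyModule`, `save`, `load` and `round_trip` are hypothetical helpers. The writer method is named `write` and the writer destructor `~Writer`, which is what the interface in the commit message appears to intend.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor: a flat list of floats.
using FakeTensor = std::vector<float>;

// Abstract sink that receives one named tensor at a time.
class Writer {
 public:
  virtual ~Writer() = default;
  virtual void write(
      const std::string& key,
      const FakeTensor& tensor,
      bool is_buffer = false) = 0;
  virtual void finish() {}
};

// Abstract source that fills in one named tensor at a time.
class Reader {
 public:
  virtual ~Reader() = default;
  virtual void read(
      const std::string& key,
      FakeTensor& tensor,
      bool is_buffer = false) = 0;
  virtual void finish() {}
};

// Hypothetical backend: keeps everything in a std::map instead of a file.
using Storage = std::map<std::string, FakeTensor>;

class MapWriter : public Writer {
 public:
  explicit MapWriter(Storage& storage) : storage_(storage) {}
  void write(
      const std::string& key,
      const FakeTensor& tensor,
      bool /*is_buffer*/ = false) override {
    storage_[key] = tensor;
  }

 private:
  Storage& storage_;
};

class MapReader : public Reader {
 public:
  explicit MapReader(const Storage& storage) : storage_(storage) {}
  void read(
      const std::string& key,
      FakeTensor& tensor,
      bool /*is_buffer*/ = false) override {
    tensor = storage_.at(key); // throws std::out_of_range for unknown keys
  }

 private:
  const Storage& storage_;
};

// Hypothetical "module" with two named parameters.
struct ToyModule {
  FakeTensor weight;
  FakeTensor bias;
};

// The module code only sees the abstract interfaces, never the backend.
void save(const ToyModule& module, Writer& writer) {
  writer.write("weight", module.weight);
  writer.write("bias", module.bias);
  writer.finish();
}

void load(ToyModule& module, Reader& reader) {
  reader.read("weight", module.weight);
  reader.read("bias", module.bias);
  reader.finish();
}

// Saves a module into fresh storage and loads it back into a new module.
ToyModule round_trip(const ToyModule& input) {
  Storage storage;
  MapWriter writer(storage);
  save(input, writer);
  ToyModule output;
  MapReader reader(storage);
  load(output, reader);
  return output;
}
```

Because `save` and `load` depend only on `Writer` and `Reader`, swapping the map-based backend for a different format would not require touching the module code, which seems to be the point of the abstraction described in the commit message.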
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim/optimizer.h>
#include <torch/optim/sgd.h>
#include <torch/serialize.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <test/cpp/api/support.h>
#include <cstdio>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
using namespace torch::nn;
using namespace torch::serialize;
namespace {
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}
torch::Tensor save_and_load(torch::Tensor input) {
  std::stringstream stream;
  torch::save(input, stream);
  torch::Tensor tensor;
  torch::load(tensor, stream);
  return tensor;
}
} // namespace
TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  auto tempfile = torch::utils::make_tempfile();
  torch::save(x, tempfile.name);

  torch::Tensor y;
  torch::load(y, tempfile.name);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x.resize_({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(0, 1, 5);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(1, 1, 4);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (size_t i = 0; i < batch_size; i++) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
Remove caffe2::Tensor::capacity_nbytes, at::Tensor::to##name##Data, (#11876) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11876 Modern C++ api instead of macros, item() is aligned with Python frontend. caffe2::Tensor::capacity_nbytes is effecitvely unused and confusing w.r.t. caffe2::Tensor::nbytes(). codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCByte "item<uint8_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCLong "item<int64_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCInt "item<int32_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCDouble "item<double>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toByteData "data<uint8_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toLongData "data<int64_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toIntData "data<int32_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toDoubleData "data<double>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toFloatData "data<float>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCByte "item<uint8_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCLong "item<int64_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCInt "item<int32_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCDouble "item<double>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toByteData "data<uint8_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toLongData "data<int64_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toIntData "data<int32_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toDoubleData "data<double>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toFloatData "data<float>" codemod -d caffe2 --extensions 
cc,cpp,cu,cuh,h,py,hpp,mm toCComplexDouble "item<std::complex<double>>" codemod -d tc --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" Reviewed By: ezyang Differential Revision: D9948572 fbshipit-source-id: 70c9f5390d92b82c85fdd5f8a5aebca338ab413c
2018-09-24 10:39:10 -07:00
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };
  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();
Remove caffe2::Tensor::capacity_nbytes, at::Tensor::to##name##Data, (#11876) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11876 Modern C++ api instead of macros, item() is aligned with Python frontend. caffe2::Tensor::capacity_nbytes is effecitvely unused and confusing w.r.t. caffe2::Tensor::nbytes(). codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCByte "item<uint8_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCLong "item<int64_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCInt "item<int32_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCDouble "item<double>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toByteData "data<uint8_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toLongData "data<int64_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toIntData "data<int32_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toDoubleData "data<double>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toFloatData "data<float>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCByte "item<uint8_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCLong "item<int64_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCInt "item<int32_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCDouble "item<double>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toByteData "data<uint8_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toLongData "data<int64_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toIntData "data<int32_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toDoubleData "data<double>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toFloatData "data<float>" codemod -d caffe2 --extensions 
cc,cpp,cu,cuh,h,py,hpp,mm toCComplexDouble "item<std::complex<double>>" codemod -d tc --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" Reviewed By: ezyang Differential Revision: D9948572 fbshipit-source-id: 70c9f5390d92b82c85fdd5f8a5aebca338ab413c
2018-09-24 10:39:10 -07:00
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }
  auto tempfile = torch::utils::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);
  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}
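
// The XOR test above trains until an exponentially smoothed loss drops below
// 0.1 and fails if that takes 3000 epochs or more. The block below is a minimal,
// torch-free sketch of that stopping rule, assuming a hypothetical callable
// `next_loss` that stands in for one optimizer step. `Convergence` and
// `train_until_converged` are illustrative names, not part of this file or of
// the PyTorch API.

```cpp
#include <cassert>

// Result of the hypothetical convergence loop below.
struct Convergence {
  int epochs;          // number of iterations executed
  float running_loss;  // smoothed loss when the loop stopped
  bool converged;      // true if the threshold was reached before the cap
};

// Assumed helper (not part of the test file): `next_loss(epoch)` returns the
// loss observed at that step. The smoothing weights 0.99 / 0.01, the initial
// value 1, and the "stop once below threshold" rule mirror the test above.
template <typename LossFn>
Convergence train_until_converged(
    LossFn next_loss,
    float threshold,
    int max_epochs) {
  float running_loss = 1.0f;
  int epoch = 0;
  while (running_loss > threshold && epoch < max_epochs) {
    const float loss = next_loss(epoch);
    // Exponential moving average: old value keeps 99% of its weight.
    running_loss = running_loss * 0.99f + loss * 0.01f;
    ++epoch;
  }
  return {epoch, running_loss, running_loss <= threshold};
}
```

// Unlike the test, which calls ASSERT_LT(epoch, 3000) inside the loop, this
// sketch reports the outcome through a flag so that it can be used outside
// GoogleTest. Because the average starts at 1 and decays by at most 1% per
// step, even a loss of exactly zero needs a few hundred steps to cross 0.1.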
TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);
  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = torch::utils::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);
Replace cursors with OrderedDict (#13427) Summary: This is a pre-cursor diff to Python <-> C++ frontend integration -- I have a follow-up PR coming for that. This PR changes the C++ frontend module interface to replace the custom "cursor"s I introduced some time ago with `OrderedDict`. I introduced cursors at the time as a convenient way of applying functions and query operations on a modules' parameters, buffers and modules, allowing things like `module.parameters().map(my_func)`. However, I noticed that (1) this functionality is easily implement-able on top of a regular data structure and (2) more importantly, using OrderedDicts is much, much easier for Python integration. This is especially true given that ScriptModule today also uses OrderedDict. Since C++ frontend modules and ScriptModules will soon too share as many implementation details as possible, it is overall the best move to ditch the custom cursor datastructure and pervasively use OrderedDict everywhere. For this I did: 1. Changed the C++ frontend module interface to more closely match the Python one by providing `parameters()`, `named_parameters()` and other methods Python provides. This is very important for the following diff which binds these into Python for inter-op with Python modules. 2. In lieu of the `Cursor::apply()` method I added `nn::Module::apply`. This again is one more unifying step between Python and C++, since Python modules have an apply function too. 3. Deleted all uses of Cursor. 4. Tidied and beefed up the `OrderedDict` class. In particular, I made `OrderedDict::Item` store an `std::pair` under the hood, because that is trivial to bind into Python and saved me a lot of headaches. `key` and `value` become methods instead of fields, which they should have been from the very start anyway because it allows exactly these kinds of changes, as per usual good software engineering principle of encapsulation. 5. Added many tests for the OrderedDict use in `nn::Module`. 
ebetica ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/13427 Differential Revision: D12894092 Pulled By: goldsborough fbshipit-source-id: 715770c95a9643753a1db26d7f9da9a78619a15d
2018-11-07 10:53:07 -08:00
  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }
  // Make some optimizers with momentum (and thus state)
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  // Do 2 steps of model1
  step(optim1, model1);
  step(optim1, model1);
  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);
  // Do 2 steps of model 3 while saving the optimizer
  step(optim3, model3);
  auto optim_tempfile = torch::utils::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);
  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
const auto& name = p.key();
// Model 1 and 3 should be the same
ASSERT_TRUE(
param1[name].norm().item<float>() == param3[name].norm().item<float>());
ASSERT_TRUE(
param1[name].norm().item<float>() != param2[name].norm().item<float>());
}
}
TEST(SerializeTest, XOR_CUDA) {
torch::manual_seed(0);
// We had better be able to save and load an XOR model!
auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda = false) {
auto inputs = torch::empty({batch_size, 2});
auto labels = torch::empty({batch_size});
if (is_cuda) {
inputs = inputs.cuda();
labels = labels.cuda();
}
for (size_t i = 0; i < batch_size; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
}
auto x = model->forward<torch::Tensor>(inputs);
return torch::binary_cross_entropy(x, labels);
};
auto model = xor_model();
auto model2 = xor_model();
auto model3 = xor_model();
auto optimizer = torch::optim::SGD(
model->parameters(),
torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
1e-6));
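// Train until the exponential moving average of the loss drops below 0.1;
// the test fails if that takes 3000 or more iterations.
float running_loss = 1;
int epoch = 0;
while (running_loss > 0.1) {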
torch::Tensor loss = getLoss(model, 4);
optimizer.zero_grad();
loss.backward();
optimizer.step();
running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
ASSERT_LT(epoch, 3000);
epoch++;
}
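// Round trip on CPU: save the trained model and load it into model2.
auto tempfile = torch::utils::make_tempfile();
torch::save(model, tempfile.name);
torch::load(model2, tempfile.name);
auto loss = getLoss(model2, 100);
ASSERT_LT(loss.item<float>(), 0.1);
// Move the loaded model to CUDA and check that it still solves XOR.
model2->to(torch::kCUDA);
loss = getLoss(model2, 100, true);
ASSERT_LT(loss.item<float>(), 0.1);
// Round trip from CUDA: save model2, load it into model3, and evaluate
// model3 with CUDA inputs.
auto tempfile2 = torch::utils::make_tempfile();
torch::save(model2, tempfile2.name);
torch::load(model3, tempfile2.name);
loss = getLoss(model3, 100, true);
ASSERT_LT(loss.item<float>(), 0.1);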
}