mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28039 Right now, torch::save() uses std::ostream, which results in unnecessary data copies in practice. Similar for torch::load(). Adding a std::function<size_t(const void*, size_t)> as an output option, parallel to the existing filename and std::ostream apis, gives users the flexibility to emit directly to a backing store. For a simple case of appending the output to a std::string, we observe significant benchmark savings (on order of -50%), even with the minor std::function<> dispatch overhead. The main reason is that std::ostringstream effectively requires 2 extra copies of the data beyond a simple string.append lambda. We also provide a parallel api for the load(), though this one is slightly more complex due to the need to do arbitrary position reads. Test Plan: buck test mode/dev-nosan caffe2/test/... (Basic serialization test in caffe2/test/cpp/api/serialize.cpp) Benchmark in experimental/jeremyl/c2/SerializationBench.cpp, with D17823443 (1M time goes from 90ms -> 40ms, albeit with crc patch applied) Differential Revision: D17939034 fbshipit-source-id: 344cce46f74b6438cb638a8cfbeccf4e1aa882d7
69 lines
1.9 KiB
C++
69 lines
1.9 KiB
C++
#include <cstdio>
|
|
#include <string>
|
|
#include <array>
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include "caffe2/serialize/inline_container.h"
|
|
|
|
namespace caffe2 {
|
|
namespace serialize {
|
|
namespace {
|
|
|
|
// Writes two records through a PyTorchStreamWriter that emits into an
// in-memory stream via a writer callback, reads them back through a
// PyTorchStreamReader, and checks that:
//   * the reader reports exactly the records that were written,
//   * each record's payload round-trips unchanged,
//   * each record's reported offset points at the raw payload bytes inside
//     the serialized buffer, and
//   * each record's offset is a multiple of kFieldAlignment.
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
  // Alignment (in bytes) this test expects every record payload to start on.
  constexpr size_t kFieldAlignment = 64;

  std::ostringstream oss;
  // write records through writers
  // The callback reports the number of bytes consumed; 0 signals that the
  // underlying stream went into a failed state.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // Fill each buffer with a descending pattern (size, size - 1, ..., 1) so
  // that the two payloads are distinguishable and position-sensitive.
  std::array<char, 127> data1;
  for (size_t i = 0; i < data1.size(); ++i) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  std::array<char, 64> data2;
  for (size_t i = 0; i < data2.size(); ++i) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2", data2.data(), data2.size());
  writer.writeEndOfFile();

  const std::string the_file = oss.str();

  // Also dump the archive to disk (presumably for manual inspection; nothing
  // in this test reads it back). Open in binary mode: the payload is raw
  // bytes and text mode would translate newlines on some platforms.
  std::ofstream foo("output.zip", std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasRecord("key1"));
  ASSERT_TRUE(reader.hasRecord("key2"));
  ASSERT_FALSE(reader.hasRecord("key2000"));

  at::DataPtr data_ptr;
  int64_t size;

  std::tie(data_ptr, size) = reader.getRecord("key1");
  const size_t off1 = reader.getRecordOffset("key1");
  ASSERT_EQ(size, static_cast<int64_t>(data1.size()));
  ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  // Guard the raw-buffer comparison below against an out-of-range offset.
  ASSERT_LE(off1 + data1.size(), the_file.size());
  ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
  ASSERT_EQ(off1 % kFieldAlignment, 0U);

  std::tie(data_ptr, size) = reader.getRecord("key2");
  const size_t off2 = reader.getRecordOffset("key2");
  ASSERT_EQ(off2 % kFieldAlignment, 0U);

  ASSERT_EQ(size, static_cast<int64_t>(data2.size()));
  ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  // Guard the raw-buffer comparison below against an out-of-range offset.
  ASSERT_LE(off2 + data2.size(), the_file.size());
  ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
}
|
|
|
|
} // namespace
|
|
} // namespace serialize
|
|
} // namespace caffe2
|