Files
pytorch/caffe2/serialize/inline_container_test.cc
atannous 3ed1569e86 Adding serialization ID to inline container (#100994)
Summary:
To better track models after serialization, this change writes a serialization_id, in the form of a UUID, to the inline container. Having this ID enables traceability of a model across saving and loading events.
serialization_id is generated as a new UUID every time serialization takes place. It can be thought of as a model snapshot identifier at the time of serialization.

Test Plan:
```
buck2 test @//mode/dev //caffe2/caffe2/serialize:inline_container_test
```

Local tests:
```
buck2 run @//mode/opt //scripts/atannous:example_pytorch_package
buck2 run @//mode/opt //scripts/atannous:example_pytorch
buck2 run @//mode/opt //scripts/atannous:example_pytorch_script
```

```
$ unzip -l output.pt
Archive:  output.pt
  Length      Date    Time    Name
---------  ---------- -----   ----
       36  00-00-1980 00:00   output/.data/serialization_id
      358  00-00-1980 00:00   output/extra/producer_info.json
       58  00-00-1980 00:00   output/data.pkl
      261  00-00-1980 00:00   output/code/__torch__.py
      326  00-00-1980 00:00   output/code/__torch__.py.debug_pkl
        4  00-00-1980 00:00   output/constants.pkl
        2  00-00-1980 00:00   output/version
---------                     -------
     1045                     7 files
```

```
unzip -p output.pt "output/.data/serialization_id"
a9f903df-cbf6-40e3-8068-68086167ec60
```

Differential Revision: D45683657

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100994
Approved by: https://github.com/davidberard98
2023-05-17 17:08:48 +00:00

252 lines
8.3 KiB
C++

#include <cstdio>
#include <string>
#include <array>
#include <gtest/gtest.h>
#include "caffe2/serialize/inline_container.h"
#include "c10/util/irange.h"
#include <c10/util/uuid.h>
namespace caffe2 {
namespace serialize {
namespace {
// Writes two records ("key1": 127 bytes, "key2": 64 bytes) through
// PyTorchStreamWriter into an in-memory stream, then reads them back with
// PyTorchStreamReader. For each record the test checks the reported size, the
// payload bytes, that the payload is found at getRecordOffset() inside the
// serialized archive, and that this offset is a multiple of kFieldAlignment.
// The owning, in-place and chunked getRecord() overloads are all exercised.
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
  const int64_t kFieldAlignment = 64L;

  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    // Report 0 bytes written once the stream is in a failed state.
    return oss ? n : 0;
  });

  // Payload bytes count down from size() to 1, so no byte is zero.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  // Held by reference: the serialization-id check after writeEndOfFile()
  // below relies on this set reflecting records added later.
  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  // Keep a copy of the archive on disk (handy for manual inspection).
  std::string the_file = oss.str();
  std::ofstream foo("output.zip");
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasRecord("key1"));
  ASSERT_TRUE(reader.hasRecord("key2"));
  ASSERT_FALSE(reader.hasRecord("key2000"));

  at::DataPtr data_ptr;
  size_t size = 0;
  std::tie(data_ptr, size) = reader.getRecord("key1");
  size_t off1 = reader.getRecordOffset("key1");
  ASSERT_EQ(size, data1.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
  ASSERT_EQ(off1 % kFieldAlignment, 0);

  // inplace getRecord() test
  std::vector<uint8_t> dst(size);
  size_t ret = reader.getRecord("key1", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);

  // chunked getRecord() test
  // Wipe the buffer first: it still holds the bytes from the in-place read,
  // so without this the comparison below could not detect a chunked read
  // that copies nothing.
  dst.assign(size, 0);
  ret = reader.getRecord(
      "key1", dst.data(), size, 3, [](void* to, const void* from, size_t n) {
        memcpy(to, from, n);
      });
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);

  std::tie(data_ptr, size) = reader.getRecord("key2");
  size_t off2 = reader.getRecordOffset("key2");
  ASSERT_EQ(off2 % kFieldAlignment, 0);
  ASSERT_EQ(size, data2.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);

  // inplace getRecord() test
  dst.resize(size);
  ret = reader.getRecord("key2", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);

  // chunked getRecord() test
  // Same reasoning as for key1: start from a buffer that cannot already match.
  dst.assign(size, 0);
  ret = reader.getRecord(
      "key2", dst.data(), size, 3, [](void* to, const void* from, size_t n) {
        memcpy(to, from, n);
      });
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
}
// Serializes "key1" and "key2", then asks the reader for the missing "key3"
// through each getRecord() overload and expects c10::Error every time. The
// final check confirms the reader still answers hasRecord() afterwards.
TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
  std::ostringstream archive;
  // Writer callback: append to the in-memory archive, report 0 on failure.
  PyTorchStreamWriter writer(
      [&archive](const void* bytes, size_t len) -> size_t {
        archive.write(static_cast<const char*>(bytes), len);
        if (!archive) {
          return 0;
        }
        return len;
      });

  // First payload: 127 bytes counting down from 127 to 1.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  for (size_t idx = 0; idx < data1.size(); ++idx) {
    data1[idx] = data1.size() - idx;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // Second payload: 64 bytes counting down from 64 to 1.
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (size_t idx = 0; idx < data2.size(); ++idx) {
    data2[idx] = data2.size() - idx;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const auto& written_records = writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  // Finalizing is expected to add the serialization id record.
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = archive.str();
  {
    // Dump the archive to disk; the stream closes at the end of this scope.
    std::ofstream dump("output2.zip");
    dump.write(the_file.data(), the_file.size());
  }

  // Read the archive back from memory.
  std::istringstream input(the_file);
  PyTorchStreamReader reader(&input);

  // Owning overload.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(reader.getRecord("key3"), c10::Error);

  // In-place overload.
  std::vector<uint8_t> scratch(data1.size());
  EXPECT_THROW(
      reader.getRecord("key3", scratch.data(), data1.size()), c10::Error);

  // Chunked overload.
  EXPECT_THROW(
      reader.getRecord(
          "key3",
          scratch.data(),
          data1.size(),
          3,
          [](void* to, const void* from, size_t len) {
            memcpy(to, from, len);
          }),
      c10::Error);

  // Reader should still work after throwing
  EXPECT_TRUE(reader.hasRecord("key1"));
}
// Writes two records whose names end in ".debug_pkl", then reads the archive
// with setShouldLoadDebugSymbol(false). In that mode the test expects the
// reader to treat such records as absent: hasRecord() is false and every
// getRecord() overload reports a size of 0.
TEST(PytorchStreamWriterAndReader, SkipDebugRecords) {
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    // Report 0 bytes written once the stream is in a failed state.
    return oss ? n : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());

  // Held by reference: the serialization-id check after writeEndOfFile()
  // below relies on this set reflecting records added later.
  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
  ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  // Dumped under its own file name so this archive does not overwrite (or
  // race with) the dump written by another test in this file.
  std::string the_file = oss.str();
  std::ofstream foo("output4.zip");
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);
  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));

  // Owning overload: expected to report an empty record.
  at::DataPtr ptr;
  size_t size = 0;
  std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
  EXPECT_EQ(size, 0);

  // In-place overload.
  std::vector<uint8_t> dst(data1.size());
  size_t ret = reader.getRecord("key1.debug_pkl", dst.data(), data1.size());
  EXPECT_EQ(ret, 0);

  // Chunked overload.
  ret = reader.getRecord(
      "key1.debug_pkl",
      dst.data(),
      data1.size(),
      3,
      [](void* to, const void* from, size_t n) { memcpy(to, from, n); });
  EXPECT_EQ(ret, 0);
}
// Tries to write a second, freshly generated id under
// kSerializationIdRecordName. The test expects that write to be ignored (no
// record is registered before finalization) and the id read back from the
// archive to be the one the writer reported up front.
TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
  std::ostringstream archive;
  // Writer callback: append to the in-memory archive, report 0 on failure.
  PyTorchStreamWriter writer(
      [&archive](const void* bytes, size_t len) -> size_t {
        archive.write(static_cast<const char*>(bytes), len);
        if (!archive) {
          return 0;
        }
        return len;
      });

  // Id reported by the writer vs. an unrelated id we attempt to inject.
  auto expected_id = writer.serializationId();
  auto injected_id = uuid::generate_uuid_v4();
  writer.writeRecord(
      kSerializationIdRecordName, injected_id.c_str(), injected_id.size());

  // The injected record must not have been registered.
  const auto& written_records = writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 0);

  // Finalizing is expected to add exactly one serialization id record.
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = archive.str();
  {
    // Dump the archive to disk; the stream closes at the end of this scope.
    std::ofstream dump("output3.zip");
    dump.write(the_file.data(), the_file.size());
  }

  // Read the archive back from memory and compare ids.
  std::istringstream input(the_file);
  PyTorchStreamReader reader(&input);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_EQ(reader.serializationId(), expected_id);
}
} // namespace
} // namespace serialize
} // namespace caffe2