In torch::save(), make padding computation faster. (#29425)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29425

This change saves roughly 5-6% in the TorchSaveSmallTensor benchmark
(torch::save() on a tensor with 64 random floats) by reusing the
padding string across records.
ghstack-source-id: 93517961

Test Plan:
Correctness: buck test mode/dev-nosan caffe2/test/...
   Benchmark: buck build mode/opt experimental/jeremyl/c2/...
     buck-out/opt/gen/experimental/jeremyl/c2/SerializationBench

Differential Revision: D18385731

fbshipit-source-id: 20bcbe1efd2fb7e3012dd68080542f2a74a7d4f2
This commit is contained in:
Jeremy Lilley
2019-11-08 15:01:53 -08:00
committed by Facebook Github Bot
parent 675a4cb9fb
commit e80f7506c2
2 changed files with 22 additions and 12 deletions

View File

@@ -148,12 +148,17 @@ constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
static std::string getPadding(size_t cursor, const std::string& filename, size_t size) {
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() + sizeof(mz_uint16) * 2;
static size_t getPadding(
size_t cursor,
size_t filename_size,
size_t size,
std::string& padding_buf) {
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
sizeof(mz_uint16) * 2;
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
start += sizeof(mz_uint16) * 2;
if (size >= MZ_UINT32_MAX) {
start += 2*sizeof(mz_uint64);
start += 2 * sizeof(mz_uint64);
}
if (cursor >= MZ_UINT32_MAX) {
start += sizeof(mz_uint64);
@@ -162,13 +167,16 @@ static std::string getPadding(size_t cursor, const std::string& filename, size_t
size_t mod = start % kFieldAlignment;
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
size_t padding_size = next_offset - start;
std::string buf(padding_size + 4, 'Z');
size_t padding_size_plus_fbxx = padding_size + 4;
if (padding_buf.size() < padding_size_plus_fbxx) {
padding_buf.append(padding_size_plus_fbxx - padding_buf.size(), 'Z');
}
// zip extra encoding (key, size_of_extra_bytes)
buf[0] = 'F';
buf[1] = 'B';
buf[2] = (uint8_t) padding_size;
buf[3] = (uint8_t) (padding_size >> 8);
return buf;
padding_buf[0] = 'F';
padding_buf[1] = 'B';
padding_buf[2] = (uint8_t)padding_size;
padding_buf[3] = (uint8_t)(padding_size >> 8);
return padding_size_plus_fbxx;
}
bool PyTorchStreamReader::hasRecord(const std::string& name) {
@@ -297,7 +305,8 @@ void PyTorchStreamWriter::writeRecord(
AT_ASSERT(!finalized_);
AT_ASSERT(!archive_name_plus_slash_.empty());
std::string full_name = archive_name_plus_slash_ + name;
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
size_t padding_size =
getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
mz_zip_writer_add_mem_ex_v2(
ar_.get(),
@@ -310,8 +319,8 @@ void PyTorchStreamWriter::writeRecord(
0,
0,
nullptr,
padding.c_str(),
padding.size(),
padding_.c_str(),
padding_size,
nullptr,
0);
valid("writing file ", name.c_str());

View File

@@ -156,6 +156,7 @@ class CAFFE2_API PyTorchStreamWriter final {
std::unique_ptr<mz_zip_archive> ar_;
std::string archive_name_;
std::string archive_name_plus_slash_;
std::string padding_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
bool finalized_ = false;