mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
In torch::save(), make padding computation faster. (#29425)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29425. This change saves roughly 5-6% in the TorchSaveSmallTensor benchmark (torch::save() on a tensor with 64 random floats) by reusing the padding string across records. ghstack-source-id: 93517961. Test Plan — Correctness: `buck test mode/dev-nosan caffe2/test/...`. Benchmark: `buck build mode/opt experimental/jeremyl/c2/...`, then run `buck-out/opt/gen/experimental/jeremy/c2/SerializationBench`. Differential Revision: D18385731. fbshipit-source-id: 20bcbe1efd2fb7e3012dd68080542f2a74a7d4f2
This commit is contained in:
committed by
Facebook Github Bot
parent
675a4cb9fb
commit
e80f7506c2
@@ -148,12 +148,17 @@ constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
|
||||
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
|
||||
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
|
||||
|
||||
static std::string getPadding(size_t cursor, const std::string& filename, size_t size) {
|
||||
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() + sizeof(mz_uint16) * 2;
|
||||
static size_t getPadding(
|
||||
size_t cursor,
|
||||
size_t filename_size,
|
||||
size_t size,
|
||||
std::string& padding_buf) {
|
||||
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
|
||||
sizeof(mz_uint16) * 2;
|
||||
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
|
||||
start += sizeof(mz_uint16) * 2;
|
||||
if (size >= MZ_UINT32_MAX) {
|
||||
start += 2*sizeof(mz_uint64);
|
||||
start += 2 * sizeof(mz_uint64);
|
||||
}
|
||||
if (cursor >= MZ_UINT32_MAX) {
|
||||
start += sizeof(mz_uint64);
|
||||
@@ -162,13 +167,16 @@ static std::string getPadding(size_t cursor, const std::string& filename, size_t
|
||||
size_t mod = start % kFieldAlignment;
|
||||
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
|
||||
size_t padding_size = next_offset - start;
|
||||
std::string buf(padding_size + 4, 'Z');
|
||||
size_t padding_size_plus_fbxx = padding_size + 4;
|
||||
if (padding_buf.size() < padding_size_plus_fbxx) {
|
||||
padding_buf.append(padding_size_plus_fbxx - padding_buf.size(), 'Z');
|
||||
}
|
||||
// zip extra encoding (key, size_of_extra_bytes)
|
||||
buf[0] = 'F';
|
||||
buf[1] = 'B';
|
||||
buf[2] = (uint8_t) padding_size;
|
||||
buf[3] = (uint8_t) (padding_size >> 8);
|
||||
return buf;
|
||||
padding_buf[0] = 'F';
|
||||
padding_buf[1] = 'B';
|
||||
padding_buf[2] = (uint8_t)padding_size;
|
||||
padding_buf[3] = (uint8_t)(padding_size >> 8);
|
||||
return padding_size_plus_fbxx;
|
||||
}
|
||||
|
||||
bool PyTorchStreamReader::hasRecord(const std::string& name) {
|
||||
@@ -297,7 +305,8 @@ void PyTorchStreamWriter::writeRecord(
|
||||
AT_ASSERT(!finalized_);
|
||||
AT_ASSERT(!archive_name_plus_slash_.empty());
|
||||
std::string full_name = archive_name_plus_slash_ + name;
|
||||
std::string padding = getPadding(ar_->m_archive_size, full_name, size);
|
||||
size_t padding_size =
|
||||
getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
|
||||
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
|
||||
mz_zip_writer_add_mem_ex_v2(
|
||||
ar_.get(),
|
||||
@@ -310,8 +319,8 @@ void PyTorchStreamWriter::writeRecord(
|
||||
0,
|
||||
0,
|
||||
nullptr,
|
||||
padding.c_str(),
|
||||
padding.size(),
|
||||
padding_.c_str(),
|
||||
padding_size,
|
||||
nullptr,
|
||||
0);
|
||||
valid("writing file ", name.c_str());
|
||||
|
||||
@@ -156,6 +156,7 @@ class CAFFE2_API PyTorchStreamWriter final {
|
||||
std::unique_ptr<mz_zip_archive> ar_;
|
||||
std::string archive_name_;
|
||||
std::string archive_name_plus_slash_;
|
||||
std::string padding_;
|
||||
std::ofstream file_stream_;
|
||||
std::function<size_t(const void*, size_t)> writer_func_;
|
||||
bool finalized_ = false;
|
||||
|
||||
Reference in New Issue
Block a user