From e66445878df2fe9fe8473bb8b467e00b05f26ea1 Mon Sep 17 00:00:00 2001
From: Mike Ruberry
Date: Wed, 24 Jun 2020 12:39:42 -0700
Subject: [PATCH] Adds dynamic versioning pattern (#40279)

Summary:

BC NOTE: This change makes it so modules saved with torch.jit.save in PyTorch 1.6
can be loaded by previous versions of PyTorch unless they use torch.div or (soon)
torch.full. It also lets tensors saved using torch.save be loaded by previous
versions. So this is the opposite of BC-breaking, but I'm using that label to
highlight this issue since we don't have a "BC-improving" label.

PR NOTE: When an operator's semantics change in PyTorch we want to do two things:

1) Preserve the semantics of older serialized TorchScript programs that use the
   operator
2) Ensure the new semantics are respected

Historically, this meant writing a Versioned Symbol that remaps older versions of
the operator onto current PyTorch code (1), and bumping the produced file format
version (2). Unfortunately, bumping the produced file format version is a nuclear
option for ensuring semantics are respected, since it also prevents older versions
of PyTorch from loading anything (even tensors!) from newer versions.

Dynamic versioning addresses the nuclear consequences of bumping the produced file
format version by only bumping it when necessary: that is, when an operator with
changed semantics is detected in the serialized TorchScript. This prevents
TorchScript programs that use the changed operator from loading on earlier
versions of PyTorch, as desired, but has no impact on programs that don't use the
changed operator.

Note that this change only applies to torch.jit.save and torch.jit.load.
torch.save pickles the given object (using pickle by default), which records the
function's Python directly, so dynamic versioning does not apply to it.

No new tests for this behavior are added, since the existing tests for versioned
division in test_save_load already validate that models with div are loaded
correctly at version 4.
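
To make the effect concrete, here is a rough sketch (not part of this patch) of how
the dynamic version could be observed from Python after this change. It assumes the
archive written by torch.jit.save is an ordinary zip file containing a "version"
record, that the record holds the base produced file format version (3) for modules
without changed operators and the operator's dynamic version (4 for torch.div)
otherwise; the module and helper names below are illustrative only.

    import io
    import zipfile

    import torch


    class PlainModule(torch.nn.Module):
        def forward(self, x):
            # No operators with changed semantics, so the archive keeps the
            # base produced file format version.
            return x + x


    class DivModule(torch.nn.Module):
        def forward(self, x, y):
            # Tensor division maps to aten::div, whose semantics changed in 1.6,
            # so the archive should carry (at least) dynamic version 4.
            return x / y


    def archive_version(module):
        # torch.jit.save writes a zip archive; its "version" record holds the
        # file format version a reader must understand.
        buffer = io.BytesIO()
        torch.jit.save(torch.jit.script(module), buffer)
        buffer.seek(0)
        with zipfile.ZipFile(buffer) as archive:
            name = next(n for n in archive.namelist() if n.endswith("/version"))
            return int(archive.read(name).decode().strip())


    print(archive_version(PlainModule()))  # expected: 3 (base version after this patch)
    print(archive_version(DivModule()))    # expected: 4 (dynamic version for torch.div)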
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40279 Reviewed By: dzhulgakov Differential Revision: D22168291 Pulled By: mruberry fbshipit-source-id: e71d6380e727e25123c7eedf6d80e5d7f1fe9f95 --- caffe2/serialize/inline_container.cc | 16 ++++++-- caffe2/serialize/inline_container.h | 39 +++++++++++++++++-- test/cpp/jit/test_save_load.cpp | 34 +++++++++++++++- test/cpp/jit/tests.h | 1 + torch/csrc/jit/frontend/versioned_symbols.cpp | 15 +++++++ torch/csrc/jit/frontend/versioned_symbols.h | 4 ++ .../csrc/jit/serialization/export_module.cpp | 16 ++++++++ torch/csrc/jit/serialization/python_print.cpp | 16 ++++++++ torch/csrc/jit/serialization/python_print.h | 1 + torch/jit/__init__.py | 16 ++++---- torch/serialization.py | 7 ---- 11 files changed, 143 insertions(+), 22 deletions(-) diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index ca6009fa3e4..419105c12d5 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -275,7 +276,8 @@ PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name) PyTorchStreamWriter::PyTorchStreamWriter( const std::function& writer_func) - : archive_name_("archive"), writer_func_(writer_func) { + : archive_name_("archive"), + writer_func_(writer_func) { setup(archive_name_); } @@ -303,10 +305,10 @@ void PyTorchStreamWriter::setup(const string& file_name) { mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64); valid("initializing archive ", file_name.c_str()); +} - std::string version = c10::to_string(kProducedFileFormatVersion); - version.push_back('\n'); - writeRecord("version", version.c_str(), version.size()); +void PyTorchStreamWriter::setMinVersion(const uint64_t version) { + version_ = std::max(version, version_); } void PyTorchStreamWriter::writeRecord( @@ -339,8 +341,14 @@ void PyTorchStreamWriter::writeRecord( } void PyTorchStreamWriter::writeEndOfFile() { + // Rewrites version info + std::string version = c10::to_string(version_); + version.push_back('\n'); + writeRecord("version", version.c_str(), version.size()); + AT_ASSERT(!finalized_); finalized_ = true; + mz_zip_writer_finalize_archive(ar_.get()); mz_zip_writer_end(ar_.get()); valid("writing central directory for archive ", archive_name_.c_str()); diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 4d776fab22a..ae802d664f6 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -94,14 +94,44 @@ constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L; // Versions (i.e. why was the version number bumped?) + +// Note [Dynamic Versions and torch.jit.save vs. torch.save] +// +// Our versioning scheme has a "produced file format version" which +// describes how an archive is to be read. The version written in an archive +// is at least this current produced file format version, but may be greater +// if it includes certain symbols. We refer to these conditional versions +// as "dynamic," since they are identified at runtime. +// +// Dynamic versioning is useful when an operator's semantics are updated. +// When using torch.jit.save we want those semantics to be preserved. If +// we bumped the produced file format version on every change, however, +// then older versions of PyTorch couldn't read even simple archives, like +// a single tensor, from newer versions of PyTorch. 
Instead, we +// assign dynamic versions to these changes that override the +// produced file format version as needed. That is, when the semantics +// of torch.div changed it was assigned dynamic version 4, and when +// torch.jit.saving modules that use torch.div those archives also have +// (at least) version 4. This prevents earlier versions of PyTorch +// from accidentally performing the wrong kind of division. Modules +// that don't use torch.div or other operators with dynamic versions +// can write the produced file format version, and these programs will +// run as expected on earlier versions of PyTorch. +// +// While torch.jit.save attempts to preserve operator semantics, +// torch.save does not. torch.save is analogous to pickling Python, so +// a function that uses torch.div will have different behavior if torch.saved +// and torch.loaded across PyTorch versions. From a technical perspective, +// torch.save ignores dynamic versioning. + // 1. Initial version // 2. Removed op_version_set version numbers // 3. Added type tags to pickle serialization of container types -// 4. Stopped integer division using torch.div +// 4. (Dynamic) Stopped integer division using torch.div // (a versioned symbol preserves the historic behavior of versions 1--3) -// 5. Stops torch.full inferring a floating point dtype +// 5. (Dynamic) Stops torch.full inferring a floating point dtype // when given bool or integer fill values. -constexpr uint64_t kProducedFileFormatVersion = 0x5L; +constexpr uint64_t kProducedFileFormatVersion = 0x3L; // Writer-specific constants constexpr uint64_t kFieldAlignment = 64; @@ -144,6 +174,8 @@ class CAFFE2_API PyTorchStreamWriter final { explicit PyTorchStreamWriter( const std::function& writer_func); + void setMinVersion(const uint64_t version); + void writeRecord( const std::string& name, const void* data, @@ -171,6 +203,7 @@ class CAFFE2_API PyTorchStreamWriter final { std::string padding_; std::ofstream file_stream_; std::function writer_func_; + uint64_t version_ = kProducedFileFormatVersion; bool finalized_ = false; bool err_seen_ = false; friend size_t ostream_write_func( diff --git a/test/cpp/jit/test_save_load.cpp b/test/cpp/jit/test_save_load.cpp index a29a2d03507..05940845d17 100644 --- a/test/cpp/jit/test_save_load.cpp +++ b/test/cpp/jit/test_save_load.cpp @@ -1,6 +1,5 @@ #include #include - #include #include @@ -8,9 +7,42 @@ #include #include +#include "caffe2/serialize/istream_adapter.h" + namespace torch { namespace jit { +// Tests that an extra file written explicitly has precedence over +// extra files written by a hook +// TODO: test for the warning, too +void testExtraFilesHookPreference() { + const auto script = R"JIT( + def forward(self): + x = torch.rand(5, 5) + x = x.mm(x) + return x + )JIT"; + + auto module = + std::make_shared("Module", std::make_shared()); + module->define(script); + std::ostringstream oss; + std::unordered_map extra_files; + extra_files["metadata.json"] = "abc"; + SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap { + return {{"metadata.json", "def"}}; + }); + module->save(oss, extra_files); + SetExportModuleExtraFilesHook(nullptr); + + std::istringstream iss(oss.str()); + caffe2::serialize::IStreamAdapter adapter{&iss}; + std::unordered_map loaded_extra_files; + loaded_extra_files["metadata.json"] = ""; + auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files); + ASSERT_EQ(loaded_extra_files["metadata.json"], "abc"); +} + void testSaveExtraFilesHook() { // no secrets { diff --git 
a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 0231a1f3c13..381108fe5dd 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -66,6 +66,7 @@ namespace jit { _(QualifiedName) \ _(ClassImport) \ _(ScriptObject) \ + _(ExtraFilesHookPreference) \ _(SaveExtraFilesHook) \ _(TypeTags) \ _(DCE) \ diff --git a/torch/csrc/jit/frontend/versioned_symbols.cpp b/torch/csrc/jit/frontend/versioned_symbols.cpp index 750b94753da..8e39e6f4247 100644 --- a/torch/csrc/jit/frontend/versioned_symbols.cpp +++ b/torch/csrc/jit/frontend/versioned_symbols.cpp @@ -76,6 +76,12 @@ static std::unordered_map symbol_range_map({ {0, 4, Symbol::fromQualString("upgraders::full_0_4")}}, }); +static std::unordered_map kind_min_version_map({ + {aten::div, 4}, + {aten::div_, 4}, + {aten::full, 5}, // NOLINT(cppcoreguidelines-avoid-magic-numbers) +}); + Symbol get_symbol_for_version(const Symbol name, const uint64_t version) { auto it = symbol_range_map.find(name); if (it == symbol_range_map.end()) { @@ -90,5 +96,14 @@ Symbol get_symbol_for_version(const Symbol name, const uint64_t version) { return name; } +uint64_t get_min_version_for_kind(const NodeKind& kind) { + auto it = kind_min_version_map.find(kind); + if (it == kind_min_version_map.end()) { + return 0; + } + + return it->second; +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/frontend/versioned_symbols.h b/torch/csrc/jit/frontend/versioned_symbols.h index b6e6cfdc032..314040af762 100644 --- a/torch/csrc/jit/frontend/versioned_symbols.h +++ b/torch/csrc/jit/frontend/versioned_symbols.h @@ -14,5 +14,9 @@ namespace jit { TORCH_API Symbol get_symbol_for_version(const Symbol name, const uint64_t version); +// Maps the given kind to the minimum version that supports it. +// See note [Dynamic Versions and torch.jit.save vs. torch.save] +TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind); + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index d26987f0cf6..21c028cd526 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -189,6 +189,11 @@ class ScriptModuleSerializer { if (bytecode_format) { writeByteCode(module); } + + // Acquires and sets minimum (dynamic) version + for (auto& item : file_streams_) { + writer_.setMinVersion(item.value().minVersion()); + } } private: @@ -234,6 +239,17 @@ class ScriptModuleSerializer { if (hook) { ExtraFilesMap hook_files = hook(module); for (const auto& kv : hook_files) { + // Checks if the hooked file is already written in extra files, + // if so, skips it and warns + if (extra_files.find(kv.first) != extra_files.end()) { + TORCH_WARN_ONCE( + "An extra files hook attempted to write ", + kv.first, + " but ", + "this is already written in extra files and so will be skipped. 
", + "This warning will only appear once per process."); + continue; + } const std::string key = "extra/" + kv.first; writer_.writeRecord(key, kv.second.data(), kv.second.size()); } diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index f4007cbc400..01977ed97b7 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -4,11 +4,14 @@ #include #include #include +#include #include #include #include #include +#include + using c10::QualifiedName; namespace torch { @@ -694,9 +697,15 @@ struct PythonPrintImpl { } } + void checkVersion(const Node* const node) { + min_version_ = + std::max(min_version_, get_min_version_for_kind(node->kind())); + } + void printNode(Node* node, bool print_const) { WithSourceRange guard(&source_range_stack_, node); scanTypeDependencies(node); + checkVersion(node); if (!print_const && node->kind() == prim::Constant) return; switch (node->kind()) { @@ -1415,6 +1424,9 @@ struct PythonPrintImpl { // when we print this, should we error if the resulting output would // not be able to be reparsed? bool enforce_importable_; + + // The least version that supports all printed ops + uint64_t min_version_ = 0; }; PythonPrint::PythonPrint( @@ -1448,6 +1460,10 @@ const SourceRangeRecords& PythonPrint::ranges() const { return pImpl->body_.ranges(); } +uint64_t PythonPrint::minVersion() const { + return pImpl->min_version_; +} + PythonPrint::~PythonPrint() = default; } // namespace jit diff --git a/torch/csrc/jit/serialization/python_print.h b/torch/csrc/jit/serialization/python_print.h index a97d9bf39e4..33fb918dc36 100644 --- a/torch/csrc/jit/serialization/python_print.h +++ b/torch/csrc/jit/serialization/python_print.h @@ -24,6 +24,7 @@ struct TORCH_API PythonPrint { std::string str() const; const SourceRangeRecords& ranges() const; + uint64_t minVersion() const; ~PythonPrint(); diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 9cad9e6e9b4..710ba42b47e 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -156,13 +156,15 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP): containing a file name. _extra_files: Map from filename to contents which will be stored as part of `f`. - .. warning:: - If you are using Python 2, `save` does NOT support ``StringIO.StringIO`` - as a valid file-like object. This is because the write method should - return the number of bytes written; ``StringIO.write()`` does not do - this. - - Please use something like ``io.BytesIO`` instead. + .. note:: + torch.jit.save attempts to preserve the behavior of some operators + across versions. For example, dividing two integer tensors in + PyTorch 1.5 performed floor division, and if the module + containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6 + its division behavior will be preserved. The same module saved in + PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the + behavior of division changed in 1.6, and 1.5 does not know how to + replicate the 1.6 behavior. Example: diff --git a/torch/serialization.py b/torch/serialization.py index 52066a1f6d5..8e24a811d22 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -342,13 +342,6 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne .. note:: A common PyTorch convention is to save tensors using .pt file extension. - .. 
warning:: - If you are using Python 2, :func:`torch.save` does NOT support :class:`StringIO.StringIO` - as a valid file-like object. This is because the write method should return - the number of bytes written; :meth:`StringIO.write()` does not do this. - - Please use something like :class:`io.BytesIO` instead. - .. note:: The 1.6 release of PyTorch switched ``torch.save`` to use a new zipfile-based file format. ``torch.load`` still retains the ability to