Adds dynamic versioning pattern (#40279)

Summary:
BC NOTE:

This change makes it so modules saved with torch.jit.save in PyTorch 1.6 can be loaded by previous versions of PyTorch unless they use torch.div or (soon) torch.full. It also lets tensors saved using torch.save be loaded by previous versions. So this is the opposite of BC-breaking, but I'm using that label to highlight this issue since we don't have a "BC-improving" label.

PR NOTE:
When an operator's semantics change in PyTorch we want to do two things:

1) Preserve the semantics of older serialized TorchScript programs that use the operator
2) Ensure the new semantics are respected

Historically, this meant writing a Versioned Symbol that would remap older versions of the operator into current PyTorch code (1), and bumping the produced file format version (2). Unfortunately, bumping the produced file format version is a nuclear option for ensuring semantics are respected, since it also prevents older versions of PyTorch from loading anything (even tensors!) from newer versions.
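To make (1) concrete, a versioned symbol maps an operator and the range of file format versions it applies to onto a historic implementation. Below is a minimal Python sketch of that lookup; the real logic is the C++ get_symbol_for_version shown in the diff, and the aten::full key here is an assumption based on the upgraders::full_0_4 entry in symbol_range_map.

```python
# Illustrative sketch only; mirrors the shape of symbol_range_map and
# get_symbol_for_version from the diff, not the actual implementation.
SYMBOL_RANGE_MAP = {
    # operator -> (first version, last version, historic implementation)
    "aten::full": (0, 4, "upgraders::full_0_4"),  # key assumed from the diff
}

def get_symbol_for_version(name, version):
    """Remap an operator to its historic implementation when loading an
    archive whose version falls inside the mapped range."""
    entry = SYMBOL_RANGE_MAP.get(name)
    if entry is None:
        return name
    first, last, historic = entry
    return historic if first <= version <= last else name

# A version-3 archive that calls aten::full gets the old behavior preserved:
assert get_symbol_for_version("aten::full", 3) == "upgraders::full_0_4"
# A version-5 archive keeps the current operator:
assert get_symbol_for_version("aten::full", 5) == "aten::full"
```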

Dynamic versioning addresses the nuclear consequences of bumping the produced file format version by only bumping it when necessary, that is, when an operator with changed semantics is detected in the serialized TorchScript. This prevents TorchScript programs that use the changed operator from loading on earlier versions of PyTorch, as desired, but has no impact on programs that don't use the changed operator.
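Concretely, the archive's version becomes the maximum of the base produced file format version and the minimum version required by each operator the program uses. The Python below is only a sketch of that flow (the real path is PythonPrintImpl::checkVersion feeding PyTorchStreamWriter::setMinVersion); the constants mirror kProducedFileFormatVersion and kind_min_version_map from the diff.

```python
# Sketch of dynamic version selection; constants mirror the diff,
# the control flow is simplified from the C++ implementation.
BASE_PRODUCED_VERSION = 3  # kProducedFileFormatVersion after this change
KIND_MIN_VERSION = {       # kind_min_version_map
    "aten::div": 4,
    "aten::div_": 4,
    "aten::full": 5,
}

def archive_version(op_kinds):
    """Start at the base version, bumping only as far as the ops require."""
    version = BASE_PRODUCED_VERSION
    for kind in op_kinds:
        version = max(version, KIND_MIN_VERSION.get(kind, 0))
    return version

# No dynamically versioned ops: older PyTorch can still read the archive.
assert archive_version(["aten::mm", "aten::add"]) == 3
# torch.div present: the archive is stamped with (at least) version 4.
assert archive_version(["aten::mm", "aten::div"]) == 4
```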

Note that this change only applies when using torch.jit.save and torch.jit.load. torch.save pickles the given object with pickle (by default), which saves a function's Python directly rather than its TorchScript, so dynamic versioning does not apply to it.
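As a rough illustration of that difference (the module and file name below are hypothetical, and the load failure on 1.5 is the behavior described above, not something this snippet checks):

```python
# Hypothetical example contrasting torch.jit.save and torch.save.
import io
import torch

class DivModule(torch.nn.Module):
    def forward(self, a, b):
        return torch.div(a, b)

scripted = torch.jit.script(DivModule())

# torch.jit.save serializes TorchScript; because the program uses torch.div,
# dynamic versioning stamps the archive with (at least) version 4, so
# PyTorch 1.5 will refuse to load it while 1.6 loads it normally.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)

# torch.save pickles the Python object directly; dynamic versioning is not
# involved, and the loaded module uses whatever division semantics the
# loading version of PyTorch implements.
torch.save(DivModule(), "div_module.pt")
```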

No new tests for this behavior are added since the existing tests for versioned division in test_save_load already validate that models with div are loaded correctly at version 4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40279

Reviewed By: dzhulgakov

Differential Revision: D22168291

Pulled By: mruberry

fbshipit-source-id: e71d6380e727e25123c7eedf6d80e5d7f1fe9f95
Mike Ruberry
2020-06-24 12:39:42 -07:00
committed by Facebook GitHub Bot
parent a2e1a948a4
commit e66445878d
11 changed files with 143 additions and 22 deletions

View File

@@ -4,6 +4,7 @@
#include <istream>
#include <ostream>
#include <fstream>
#include <algorithm>
#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>
@@ -275,7 +276,8 @@ PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
PyTorchStreamWriter::PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)>& writer_func)
: archive_name_("archive"), writer_func_(writer_func) {
: archive_name_("archive"),
writer_func_(writer_func) {
setup(archive_name_);
}
@@ -303,10 +305,10 @@ void PyTorchStreamWriter::setup(const string& file_name) {
mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
valid("initializing archive ", file_name.c_str());
}
- std::string version = c10::to_string(kProducedFileFormatVersion);
- version.push_back('\n');
- writeRecord("version", version.c_str(), version.size());
+ void PyTorchStreamWriter::setMinVersion(const uint64_t version) {
+ version_ = std::max(version, version_);
+ }
void PyTorchStreamWriter::writeRecord(
@@ -339,8 +341,14 @@ void PyTorchStreamWriter::writeRecord(
}
void PyTorchStreamWriter::writeEndOfFile() {
// Rewrites version info
std::string version = c10::to_string(version_);
version.push_back('\n');
writeRecord("version", version.c_str(), version.size());
AT_ASSERT(!finalized_);
finalized_ = true;
mz_zip_writer_finalize_archive(ar_.get());
mz_zip_writer_end(ar_.get());
valid("writing central directory for archive ", archive_name_.c_str());

View File

@@ -94,14 +94,44 @@ constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
// Versions (i.e. why was the version number bumped?)
// Note [Dynamic Versions and torch.jit.save vs. torch.save]
//
// Our versioning scheme has a "produced file format version" which
// describes how an archive is to be read. The version written in an archive
// is at least this current produced file format version, but may be greater
// if it includes certain symbols. We refer to these conditional versions
// as "dynamic," since they are identified at runtime.
//
// Dynamic versioning is useful when an operator's semantics are updated.
// When using torch.jit.save we want those semantics to be preserved. If
// we bumped the produced file format version on every change, however,
// then older versions of PyTorch couldn't read even simple archives, like
// a single tensor, from newer versions of PyTorch. Instead, we
// assign dynamic versions to these changes that override the
// produced file format version as needed. That is, when the semantics
// of torch.div changed it was assigned dynamic version 4, and archives
// produced by torch.jit.save for modules that use torch.div also have
// (at least) version 4. This prevents earlier versions of PyTorch
// from accidentally performing the wrong kind of division. Modules
// that don't use torch.div or other operators with dynamic versions
// can write the produced file format version, and these programs will
// run as expected on earlier versions of PyTorch.
//
// While torch.jit.save attempts to preserve operator semantics,
// torch.save does not. torch.save is analogous to pickling Python, so
// a function that uses torch.div will have different behavior if torch.saved
// and torch.loaded across PyTorch versions. From a technical perspective,
// torch.save ignores dynamic versioning.
// 1. Initial version
// 2. Removed op_version_set version numbers
// 3. Added type tags to pickle serialization of container types
- // 4. Stopped integer division using torch.div
+ // 4. (Dynamic) Stopped integer division using torch.div
// (a versioned symbol preserves the historic behavior of versions 1--3)
- // 5. Stops torch.full inferring a floating point dtype
+ // 5. (Dynamic) Stops torch.full inferring a floating point dtype
// when given bool or integer fill values.
- constexpr uint64_t kProducedFileFormatVersion = 0x5L;
+ constexpr uint64_t kProducedFileFormatVersion = 0x3L;
// Writer-specific constants
constexpr uint64_t kFieldAlignment = 64;
@@ -144,6 +174,8 @@ class CAFFE2_API PyTorchStreamWriter final {
explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)>& writer_func);
void setMinVersion(const uint64_t version);
void writeRecord(
const std::string& name,
const void* data,
@@ -171,6 +203,7 @@ class CAFFE2_API PyTorchStreamWriter final {
std::string padding_;
std::ofstream file_stream_;
std::function<size_t(const void*, size_t)> writer_func_;
uint64_t version_ = kProducedFileFormatVersion;
bool finalized_ = false;
bool err_seen_ = false;
friend size_t ostream_write_func(

View File

@@ -1,6 +1,5 @@
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <sstream>
#include <torch/csrc/jit/serialization/export.h>
@@ -8,9 +7,42 @@
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>
#include "caffe2/serialize/istream_adapter.h"
namespace torch {
namespace jit {
// Tests that an extra file written explicitly has precedence over
// extra files written by a hook
// TODO: test for the warning, too
void testExtraFilesHookPreference() {
const auto script = R"JIT(
def forward(self):
x = torch.rand(5, 5)
x = x.mm(x)
return x
)JIT";
auto module =
std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
module->define(script);
std::ostringstream oss;
std::unordered_map<std::string, std::string> extra_files;
extra_files["metadata.json"] = "abc";
SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
return {{"metadata.json", "def"}};
});
module->save(oss, extra_files);
SetExportModuleExtraFilesHook(nullptr);
std::istringstream iss(oss.str());
caffe2::serialize::IStreamAdapter adapter{&iss};
std::unordered_map<std::string, std::string> loaded_extra_files;
loaded_extra_files["metadata.json"] = "";
auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
}
void testSaveExtraFilesHook() {
// no secrets
{

View File

@@ -66,6 +66,7 @@ namespace jit {
_(QualifiedName) \
_(ClassImport) \
_(ScriptObject) \
_(ExtraFilesHookPreference) \
_(SaveExtraFilesHook) \
_(TypeTags) \
_(DCE) \

View File

@@ -76,6 +76,12 @@ static std::unordered_map<Symbol, SymbolRange> symbol_range_map({
{0, 4, Symbol::fromQualString("upgraders::full_0_4")}},
});
static std::unordered_map<NodeKind, uint64_t> kind_min_version_map({
{aten::div, 4},
{aten::div_, 4},
{aten::full, 5}, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
});
Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
auto it = symbol_range_map.find(name);
if (it == symbol_range_map.end()) {
@@ -90,5 +96,14 @@ Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
return name;
}
uint64_t get_min_version_for_kind(const NodeKind& kind) {
auto it = kind_min_version_map.find(kind);
if (it == kind_min_version_map.end()) {
return 0;
}
return it->second;
}
} // namespace jit
} // namespace torch

View File

@@ -14,5 +14,9 @@ namespace jit {
TORCH_API Symbol
get_symbol_for_version(const Symbol name, const uint64_t version);
// Maps the given kind to the minimum version that supports it.
// See note [Dynamic Versions and torch.jit.save vs. torch.save]
TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind);
} // namespace jit
} // namespace torch

View File

@@ -189,6 +189,11 @@ class ScriptModuleSerializer {
if (bytecode_format) {
writeByteCode(module);
}
// Acquires and sets minimum (dynamic) version
for (auto& item : file_streams_) {
writer_.setMinVersion(item.value().minVersion());
}
}
private:
@@ -234,6 +239,17 @@ class ScriptModuleSerializer {
if (hook) {
ExtraFilesMap hook_files = hook(module);
for (const auto& kv : hook_files) {
// Checks if the hooked file is already written in extra files,
// if so, skips it and warns
if (extra_files.find(kv.first) != extra_files.end()) {
TORCH_WARN_ONCE(
"An extra files hook attempted to write ",
kv.first,
" but ",
"this is already written in extra files and so will be skipped. ",
"This warning will only appear once per process.");
continue;
}
const std::string key = "extra/" + kv.first;
writer_.writeRecord(key, kv.second.data(), kv.second.size());
}

View File

@@ -4,11 +4,14 @@
#include <c10/util/StringUtil.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/versioned_symbols.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/resource_guard.h>
#include <algorithm>
using c10::QualifiedName;
namespace torch {
@@ -694,9 +697,15 @@ struct PythonPrintImpl {
}
}
void checkVersion(const Node* const node) {
min_version_ =
std::max(min_version_, get_min_version_for_kind(node->kind()));
}
void printNode(Node* node, bool print_const) {
WithSourceRange guard(&source_range_stack_, node);
scanTypeDependencies(node);
checkVersion(node);
if (!print_const && node->kind() == prim::Constant)
return;
switch (node->kind()) {
@@ -1415,6 +1424,9 @@ struct PythonPrintImpl {
// when we print this, should we error if the resulting output would
// not be able to be reparsed?
bool enforce_importable_;
// The least version that supports all printed ops
uint64_t min_version_ = 0;
};
PythonPrint::PythonPrint(
@@ -1448,6 +1460,10 @@ const SourceRangeRecords& PythonPrint::ranges() const {
return pImpl->body_.ranges();
}
uint64_t PythonPrint::minVersion() const {
return pImpl->min_version_;
}
PythonPrint::~PythonPrint() = default;
} // namespace jit

View File

@@ -24,6 +24,7 @@ struct TORCH_API PythonPrint {
std::string str() const;
const SourceRangeRecords& ranges() const;
uint64_t minVersion() const;
~PythonPrint();

View File

@@ -156,13 +156,15 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
containing a file name.
_extra_files: Map from filename to contents which will be stored as part of `f`.
.. warning::
If you are using Python 2, `save` does NOT support ``StringIO.StringIO``
as a valid file-like object. This is because the write method should
return the number of bytes written; ``StringIO.write()`` does not do
this.
Please use something like ``io.BytesIO`` instead.
.. note::
torch.jit.save attempts to preserve the behavior of some operators
across versions. For example, dividing two integer tensors in
PyTorch 1.5 performed floor division, and if the module
containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6
its division behavior will be preserved. The same module saved in
PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the
behavior of division changed in 1.6, and 1.5 does not know how to
replicate the 1.6 behavior.
Example:

View File

@@ -342,13 +342,6 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne
.. note::
A common PyTorch convention is to save tensors using .pt file extension.
.. warning::
If you are using Python 2, :func:`torch.save` does NOT support :class:`StringIO.StringIO`
as a valid file-like object. This is because the write method should return
the number of bytes written; :meth:`StringIO.write()` does not do this.
Please use something like :class:`io.BytesIO` instead.
.. note::
The 1.6 release of PyTorch switched ``torch.save`` to use a new
zipfile-based file format. ``torch.load`` still retains the ability to