Adds dynamic versioning pattern (#40279)
Summary:

BC NOTE: This change makes it so modules saved with torch.jit.save in PyTorch 1.6 can be loaded by previous versions of PyTorch unless they use torch.div or (soon) torch.full. It also lets tensors saved using torch.save be loaded by previous versions. So this is the opposite of BC-breaking, but I'm using that label to highlight this issue since we don't have a "BC-improving" label.

PR NOTE: When an operator's semantics change in PyTorch we want to do two things:

1) Preserve the semantics of older serialized Torchscript programs that use the operator
2) Ensure the new semantics are respected

Historically, this meant writing a Versioned Symbol that would remap older versions of the operator into current PyTorch code (1), and bumping the produced file format version (2). Unfortunately, bumping the produced file format version is a nuclear option for ensuring semantics are respected, since it also prevents older versions of PyTorch from loading anything (even tensors!) from newer versions.

Dynamic versioning addresses the nuclear consequences of bumping the produced file format version by only bumping it when necessary: that is, when an operator with changed semantics is detected in the serialized Torchscript. This prevents Torchscript programs that use the changed operator from loading on earlier versions of PyTorch, as desired, but has no impact on programs that don't use the changed operator.

Note that this change is only applicable when using torch.jit.save and torch.jit.load. torch.save pickles the given object using pickle (by default), which saves a function's Python directly.

No new tests for this behavior are added since the existing tests for versioned division in test_save_load already validate that models with div are loaded correctly at version 4.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40279

Reviewed By: dzhulgakov

Differential Revision: D22168291

Pulled By: mruberry

fbshipit-source-id: e71d6380e727e25123c7eedf6d80e5d7f1fe9f95
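To make the pattern concrete, here is a minimal, self-contained C++ sketch of the write-side logic (illustrative only, not the PyTorch implementation: archiveVersion and the string-keyed map are hypothetical stand-ins for the kind_min_version_map and PyTorchStreamWriter::setMinVersion changes in the diff below). The version stamped into an archive starts at the base produced file format version and is raised only when the program uses an operator whose semantics changed.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Base version written when no changed operators appear (0x3 in this PR).
constexpr uint64_t kProducedFileFormatVersion = 0x3L;

// Operators with changed semantics, mapped to the dynamic version that
// marks the change (mirrors kind_min_version_map in the diff below).
static const std::unordered_map<std::string, uint64_t> kKindMinVersionMap = {
    {"aten::div", 4},
    {"aten::full", 5},
};

// Hypothetical helper: the version to stamp into an archive is the max of
// the base version and the minimum version required by any operator used.
uint64_t archiveVersion(const std::vector<std::string>& ops) {
  uint64_t version = kProducedFileFormatVersion;
  for (const auto& op : ops) {
    auto it = kKindMinVersionMap.find(op);
    if (it != kKindMinVersionMap.end()) {
      version = std::max(version, it->second);
    }
  }
  return version;
}

int main() {
  // No changed operators: stays at 3, so older PyTorch can read the archive.
  std::cout << archiveVersion({"aten::add", "aten::mm"}) << "\n";
  // Uses torch.div: stamped with version 4, which older readers reject.
  std::cout << archiveVersion({"aten::add", "aten::div"}) << "\n";
}

Because the version is accumulated with std::max it is monotone: a program using several changed operators is stamped with the highest version any of them requires.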
Committed by: Facebook GitHub Bot
Parent: a2e1a948a4
Commit: e66445878d
@@ -4,6 +4,7 @@
 #include <istream>
 #include <ostream>
 #include <fstream>
+#include <algorithm>
 
 #include <c10/core/Allocator.h>
 #include <c10/core/Backend.h>
@@ -275,7 +276,8 @@ PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
 
 PyTorchStreamWriter::PyTorchStreamWriter(
     const std::function<size_t(const void*, size_t)>& writer_func)
-    : archive_name_("archive"), writer_func_(writer_func) {
+    : archive_name_("archive"),
+      writer_func_(writer_func) {
   setup(archive_name_);
 }
 
@@ -303,10 +305,10 @@ void PyTorchStreamWriter::setup(const string& file_name) {
 
   mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
   valid("initializing archive ", file_name.c_str());
-
-  std::string version = c10::to_string(kProducedFileFormatVersion);
-  version.push_back('\n');
-  writeRecord("version", version.c_str(), version.size());
 }
+
+void PyTorchStreamWriter::setMinVersion(const uint64_t version) {
+  version_ = std::max(version, version_);
+}
 
 void PyTorchStreamWriter::writeRecord(
@@ -339,8 +341,14 @@ void PyTorchStreamWriter::writeRecord(
 }
 
 void PyTorchStreamWriter::writeEndOfFile() {
+  // Rewrites version info
+  std::string version = c10::to_string(version_);
+  version.push_back('\n');
+  writeRecord("version", version.c_str(), version.size());
+
   AT_ASSERT(!finalized_);
   finalized_ = true;
+
   mz_zip_writer_finalize_archive(ar_.get());
   mz_zip_writer_end(ar_.get());
   valid("writing central directory for archive ", archive_name_.c_str());
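Note the ordering this hunk relies on: setMinVersion may be called at any point while records are being serialized, so the final version is only known at finalization, and the "version" record is therefore written in writeEndOfFile rather than in setup. A toy sketch of that lifecycle, with a hypothetical ToyArchiveWriter standing in for PyTorchStreamWriter:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>

class ToyArchiveWriter {
 public:
  explicit ToyArchiveWriter(uint64_t base_version) : version_(base_version) {}

  // Raises (never lowers) the version to be stamped into the archive.
  void setMinVersion(uint64_t version) {
    version_ = std::max(version, version_);
  }

  void writeRecord(const std::string& name, const std::string& data) {
    std::cout << "record " << name << ": " << data << "\n";
  }

  // Stamps the version last, after all dynamic bumps have been observed.
  void writeEndOfFile() {
    writeRecord("version", std::to_string(version_) + "\n");
  }

 private:
  uint64_t version_;
};

int main() {
  ToyArchiveWriter writer(3);
  writer.writeRecord("code/model.py", "...");
  writer.setMinVersion(4); // the serializer saw a changed op such as aten::div
  writer.writeEndOfFile(); // the version record reflects the bump: 4
}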
@@ -94,14 +94,44 @@ constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
 
 // Versions (i.e. why was the version number bumped?)
+
+// Note [Dynamic Versions and torch.jit.save vs. torch.save]
+//
+// Our versioning scheme has a "produced file format version" which
+// describes how an archive is to be read. The version written in an archive
+// is at least this current produced file format version, but may be greater
+// if it includes certain symbols. We refer to these conditional versions
+// as "dynamic," since they are identified at runtime.
+//
+// Dynamic versioning is useful when an operator's semantics are updated.
+// When using torch.jit.save we want those semantics to be preserved. If
+// we bumped the produced file format version on every change, however,
+// then older versions of PyTorch couldn't read even simple archives, like
+// a single tensor, from newer versions of PyTorch. Instead, we
+// assign dynamic versions to these changes that override the
+// produced file format version as needed. That is, when the semantics
+// of torch.div changed it was assigned dynamic version 4, and when
+// torch.jit.saving modules that use torch.div those archives also have
+// (at least) version 4. This prevents earlier versions of PyTorch
+// from accidentally performing the wrong kind of division. Modules
+// that don't use torch.div or other operators with dynamic versions
+// can write the produced file format version, and these programs will
+// run as expected on earlier versions of PyTorch.
+//
+// While torch.jit.save attempts to preserve operator semantics,
+// torch.save does not. torch.save is analogous to pickling Python, so
+// a function that uses torch.div will have different behavior if torch.saved
+// and torch.loaded across PyTorch versions. From a technical perspective,
+// torch.save ignores dynamic versioning.
+
 // 1. Initial version
 // 2. Removed op_version_set version numbers
 // 3. Added type tags to pickle serialization of container types
-// 4. Stopped integer division using torch.div
-// 5. Stops torch.full inferring a floating point dtype
-constexpr uint64_t kProducedFileFormatVersion = 0x5L;
+// 4. (Dynamic) Stopped integer division using torch.div
+//    (a versioned symbol preserves the historic behavior of versions 1--3)
+// 5. (Dynamic) Stops torch.full inferring a floating point dtype
+//    when given bool or integer fill values.
+constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 
 // Writer-specific constants
 constexpr uint64_t kFieldAlignment = 64;
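The load-side gating these constants imply is not part of this diff; what follows is a hedged sketch, assuming the reader rejects archives stamped outside its supported range (checkArchiveVersion is an illustrative name, not the real API):

#include <cstdint>
#include <stdexcept>
#include <string>

constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;

// Rejects archives whose stamped version this build cannot interpret.
void checkArchiveVersion(uint64_t version) {
  if (version < kMinSupportedFileFormatVersion ||
      version > kMaxSupportedFileFormatVersion) {
    throw std::runtime_error(
        "Unsupported archive version: " + std::to_string(version));
  }
}

int main() {
  checkArchiveVersion(3); // fine: plain tensors saved at the base version
  checkArchiveVersion(4); // fine here; a PyTorch 1.5 reader would reject it
}

This is exactly the desired failure mode: an archive dynamically bumped to version 4 because it uses torch.div refuses to load on a reader whose maximum supported version is 3.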
@@ -144,6 +174,8 @@ class CAFFE2_API PyTorchStreamWriter final {
   explicit PyTorchStreamWriter(
       const std::function<size_t(const void*, size_t)>& writer_func);
 
+  void setMinVersion(const uint64_t version);
+
   void writeRecord(
       const std::string& name,
       const void* data,
@@ -171,6 +203,7 @@ class CAFFE2_API PyTorchStreamWriter final {
   std::string padding_;
   std::ofstream file_stream_;
   std::function<size_t(const void*, size_t)> writer_func_;
+  uint64_t version_ = kProducedFileFormatVersion;
   bool finalized_ = false;
   bool err_seen_ = false;
   friend size_t ostream_write_func(
@@ -1,6 +1,5 @@
 #include <test/cpp/jit/test_base.h>
-#include <test/cpp/jit/test_utils.h>
 
 #include <sstream>
 
 #include <torch/csrc/jit/serialization/export.h>
@@ -8,9 +7,42 @@
 #include <torch/csrc/jit/serialization/import_source.h>
 #include <torch/torch.h>
 
+#include "caffe2/serialize/istream_adapter.h"
+
 namespace torch {
 namespace jit {
 
+// Tests that an extra file written explicitly has precedence over
+// extra files written by a hook
+// TODO: test for the warning, too
+void testExtraFilesHookPreference() {
+  const auto script = R"JIT(
+    def forward(self):
+        x = torch.rand(5, 5)
+        x = x.mm(x)
+        return x
+  )JIT";
+
+  auto module =
+      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
+  module->define(script);
+  std::ostringstream oss;
+  std::unordered_map<std::string, std::string> extra_files;
+  extra_files["metadata.json"] = "abc";
+  SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
+    return {{"metadata.json", "def"}};
+  });
+  module->save(oss, extra_files);
+  SetExportModuleExtraFilesHook(nullptr);
+
+  std::istringstream iss(oss.str());
+  caffe2::serialize::IStreamAdapter adapter{&iss};
+  std::unordered_map<std::string, std::string> loaded_extra_files;
+  loaded_extra_files["metadata.json"] = "";
+  auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files);
+  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
+}
+
 void testSaveExtraFilesHook() {
   // no secrets
   {
@@ -66,6 +66,7 @@ namespace jit {
   _(QualifiedName)            \
   _(ClassImport)              \
   _(ScriptObject)             \
+  _(ExtraFilesHookPreference) \
   _(SaveExtraFilesHook)       \
   _(TypeTags)                 \
   _(DCE)                      \
@@ -76,6 +76,12 @@ static std::unordered_map<Symbol, SymbolRange> symbol_range_map({
      {0, 4, Symbol::fromQualString("upgraders::full_0_4")}},
 });
 
+static std::unordered_map<NodeKind, uint64_t> kind_min_version_map({
+    {aten::div, 4},
+    {aten::div_, 4},
+    {aten::full, 5}, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+});
+
 Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
   auto it = symbol_range_map.find(name);
   if (it == symbol_range_map.end()) {
@@ -90,5 +96,14 @@ Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
   return name;
 }
 
+uint64_t get_min_version_for_kind(const NodeKind& kind) {
+  auto it = kind_min_version_map.find(kind);
+  if (it == kind_min_version_map.end()) {
+    return 0;
+  }
+
+  return it->second;
+}
+
 } // namespace jit
 } // namespace torch
@@ -14,5 +14,9 @@ namespace jit {
 TORCH_API Symbol
 get_symbol_for_version(const Symbol name, const uint64_t version);
 
+// Maps the given kind to the minimum version that supports it.
+// See note [Dynamic Versions and torch.jit.save vs. torch.save]
+TORCH_API uint64_t get_min_version_for_kind(const NodeKind& kind);
+
 } // namespace jit
 } // namespace torch
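Together these declarations expose both halves of the scheme: get_min_version_for_kind drives the write-side version bump, while get_symbol_for_version performs the read-side remapping that preserves old semantics. Below is a simplified, self-contained sketch of the read-side lookup (plain strings stand in for Symbol and NodeKind; the full_0_4 range mirrors symbol_range_map above, while the div_0_3 entry and exact range values are illustrative assumptions):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

struct SymbolRange {
  uint64_t start;          // first archive version with the old semantics
  uint64_t end;            // last archive version with the old semantics
  std::string replacement; // upgrader implementing the old semantics
};

static const std::unordered_map<std::string, SymbolRange> symbol_range_map = {
    {"aten::div", {0, 3, "upgraders::div_0_3"}},
    {"aten::full", {0, 4, "upgraders::full_0_4"}},
};

// When importing an archive whose version falls inside an operator's
// "changed semantics" range, redirect the call to the upgrader.
std::string get_symbol_for_version(const std::string& name, uint64_t version) {
  auto it = symbol_range_map.find(name);
  if (it == symbol_range_map.end() || version < it->second.start ||
      version > it->second.end) {
    return name; // current semantics apply
  }
  return it->second.replacement;
}

int main() {
  // A version 3 archive predates the div change, so div is remapped.
  std::cout << get_symbol_for_version("aten::div", 3) << "\n";
  // A version 4 archive already carries the new div semantics.
  std::cout << get_symbol_for_version("aten::div", 4) << "\n";
}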
@@ -189,6 +189,11 @@ class ScriptModuleSerializer {
     if (bytecode_format) {
       writeByteCode(module);
     }
+
+    // Acquires and sets minimum (dynamic) version
+    for (auto& item : file_streams_) {
+      writer_.setMinVersion(item.value().minVersion());
+    }
   }
 
  private:
@@ -234,6 +239,17 @@ class ScriptModuleSerializer {
     if (hook) {
       ExtraFilesMap hook_files = hook(module);
       for (const auto& kv : hook_files) {
+        // Checks if the hooked file is already written in extra files,
+        // if so, skips it and warns
+        if (extra_files.find(kv.first) != extra_files.end()) {
+          TORCH_WARN_ONCE(
+              "An extra files hook attempted to write ",
+              kv.first,
+              " but ",
+              "this is already written in extra files and so will be skipped. ",
+              "This warning will only appear once per process.");
+          continue;
+        }
         const std::string key = "extra/" + kv.first;
         writer_.writeRecord(key, kv.second.data(), kv.second.size());
       }
@@ -4,11 +4,14 @@
 #include <c10/util/StringUtil.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/frontend/error_report.h>
+#include <torch/csrc/jit/frontend/versioned_symbols.h>
 #include <torch/csrc/jit/ir/attributes.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/resource_guard.h>
 
+#include <algorithm>
+
 using c10::QualifiedName;
 
 namespace torch {
@@ -694,9 +697,15 @@ struct PythonPrintImpl {
     }
   }
 
+  void checkVersion(const Node* const node) {
+    min_version_ =
+        std::max(min_version_, get_min_version_for_kind(node->kind()));
+  }
+
   void printNode(Node* node, bool print_const) {
     WithSourceRange guard(&source_range_stack_, node);
     scanTypeDependencies(node);
+    checkVersion(node);
     if (!print_const && node->kind() == prim::Constant)
       return;
     switch (node->kind()) {
@@ -1415,6 +1424,9 @@ struct PythonPrintImpl {
   // when we print this, should we error if the resulting output would
   // not be able to be reparsed?
   bool enforce_importable_;
+
+  // The least version that supports all printed ops
+  uint64_t min_version_ = 0;
 };
 
 PythonPrint::PythonPrint(
@@ -1448,6 +1460,10 @@ const SourceRangeRecords& PythonPrint::ranges() const {
   return pImpl->body_.ranges();
 }
 
+uint64_t PythonPrint::minVersion() const {
+  return pImpl->min_version_;
+}
+
 PythonPrint::~PythonPrint() = default;
 
 } // namespace jit
@@ -24,6 +24,7 @@ struct TORCH_API PythonPrint {
 
   std::string str() const;
   const SourceRangeRecords& ranges() const;
+  uint64_t minVersion() const;
 
   ~PythonPrint();
 
@@ -156,13 +156,15 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
             containing a file name.
         _extra_files: Map from filename to contents which will be stored as part of `f`.
 
-    .. warning::
-        If you are using Python 2, `save` does NOT support ``StringIO.StringIO``
-        as a valid file-like object. This is because the write method should
-        return the number of bytes written; ``StringIO.write()`` does not do
-        this.
-
-        Please use something like ``io.BytesIO`` instead.
+    .. note::
+        torch.jit.save attempts to preserve the behavior of some operators
+        across versions. For example, dividing two integer tensors in
+        PyTorch 1.5 performed floor division, and if the module
+        containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6
+        its division behavior will be preserved. The same module saved in
+        PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the
+        behavior of division changed in 1.6, and 1.5 does not know how to
+        replicate the 1.6 behavior.
 
     Example:
@@ -342,13 +342,6 @@ def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_ne
     .. note::
        A common PyTorch convention is to save tensors using .pt file extension.
 
-    .. warning::
-        If you are using Python 2, :func:`torch.save` does NOT support :class:`StringIO.StringIO`
-        as a valid file-like object. This is because the write method should return
-        the number of bytes written; :meth:`StringIO.write()` does not do this.
-
-        Please use something like :class:`io.BytesIO` instead.
-
     .. note::
         The 1.6 release of PyTorch switched ``torch.save`` to use a new
         zipfile-based file format. ``torch.load`` still retains the ability to