Add version number to bytecode. (#36439)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36439

A proposal of versioning in bytecode, as suggested by dzhulgakov in the internal post: https://fb.workplace.com/groups/pytorch.mobile.work/permalink/590192431851054/

kProducedBytecodeVersion is added. If the model version is not the same as the number in the code, an error will be thrown.

The updated bytecode would look like below. It's a tuple of elements, where the first element is the version number.
```
(3,
 ('__torch__.m.forward',
  (('instructions',
    (('STOREN', 1, 2),
     ('DROPR', 1, 0),
     ('MOVE', 2, 0),
     ('OP', 0, 0),
     ('RET', 0, 0))),
   ('operators', (('aten::Int', 'Tensor'),)),
   ('constants', ()),
   ('types', ()),
   ('register_size', 2))))
```

Test Plan: Imported from OSS

Differential Revision: D22433532

Pulled By: iseeyuan

fbshipit-source-id: 6d62e4abe679cf91a8e18793268ad8c1d94ce746
This commit is contained in:
Martin Yuan
2020-07-08 12:28:17 -07:00
committed by Facebook GitHub Bot
parent 58d7d91f88
commit 131a0ea277
3 changed files with 48 additions and 18 deletions

View File

@@ -133,6 +133,18 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
// when given bool or integer fill values.
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
// the version we write when the archive contains bytecode.
// It must be higher or eq to kProducedFileFormatVersion.
// Because torchscript changes is likely introduce bytecode change,
// if kProducedFileFormatVersion is increased, kProducedBytecodeVersion
// should be increased too. The relationship is:
// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion
// >= kProducedFileFormatVersion
constexpr uint64_t kProducedBytecodeVersion = 0x3L;
static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
"kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion.");
// Writer-specific constants
constexpr uint64_t kFieldAlignment = 64;

View File

@@ -16,21 +16,18 @@
// The import process to serialize the bytecode package.
// An example for bytecode.pkl of a small mobile_module looks like:
// (('__torch__.m.add_it',
// (('instructions',
// (('STOREN', 1, 2),
// ('MOVE', 1, 0),
// ('GET_ATTR', 0, 0),
// ('MOVE', 2, 0),
// ('LOADC', 0, 0),
// ('OP', 0, 0),
// ('LOADC', 1, 0),
// ('LOADC', 0, 0),
// ('OP', 1, 0),
// ('RET', 0, 0))),
// ('operators', (('_aten::add', 'Tensor'), ('_aten::add', 'Scalar'))),
// ('constants', (1, 4)),
// ('register_size', 2))),)
// (3,
// ('__torch__.m.forward',
// (('instructions',
// (('STOREN', 1, 2),
// ('DROPR', 1, 0),
// ('MOVE', 2, 0),
// ('OP', 0, 0),
// ('RET', 0, 0))),
// ('operators', (('aten::Int', 'Tensor'),)),
// ('constants', ()),
// ('types', ()),
// ('register_size', 2))))
// Note that currently the backward compatibility is not supported by bytecode.
// This format and process need to be revisted and redesigned if we want to
@@ -87,7 +84,28 @@ void print_unsupported_ops_and_throw(
void parseMethods(
const std::vector<IValue>& vals,
mobile::CompilationUnit& mcu) {
for (const auto& element : vals) {
TORCH_CHECK(
vals.size() > 0,
"Bytecode has no elements. ");
// Initialized with the version number when kProducedBytecodeVersion was
// introduced. The old models (some of them already in production) without
// version number don't have to be re-generated.
int64_t model_version = 0x3L;
size_t method_i_start = 0;
if (vals[0].isInt()) {
model_version = vals[0].toInt();
method_i_start = 1;
}
TORCH_CHECK(
model_version == caffe2::serialize::kProducedBytecodeVersion,
"Lite Interpreter verson number does not match. ",
"The code version is ",
caffe2::serialize::kProducedBytecodeVersion,
" but the model version is ",
model_version);
for (size_t i = method_i_start; i < vals.size(); ++i) {
const auto& element = vals[i];
const auto& m_tuple = element.toTuple()->elements();
const std::string& function_name = m_tuple[0].toStringRef();
IValue table = m_tuple[1];
@@ -136,7 +154,6 @@ void parseMethods(
op_item[0].toString()->string(), op_item[1].toString()->string()));
}
}
if (!unsupported_op_names.empty()) {
print_unsupported_ops_and_throw(unsupported_op_names);
};

View File

@@ -5,7 +5,6 @@
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/python_print.h>
@@ -301,6 +300,8 @@ class ScriptModuleSerializer {
void writeByteCode(const Module& module) {
std::vector<c10::IValue> elements;
elements.emplace_back(
static_cast<int64_t>(caffe2::serialize::kProducedBytecodeVersion));
moduleMethodsTuple(module, elements);
auto telements = Tup(std::move(elements));
writeArchive("bytecode", telements);