[jit] Add type tags to lists/dicts in pickle (#33255)

Summary:
Stacked PRs
 * #33474 - [jit] Remove list specializations from pickler
 * **#33255 - [jit] Add type tags to lists/dicts in pickle**

This adds a global call to `torch.jit._pickle.restore_type_tag` for
lists and dicts so that we can preserve their types after serialization.
Internal diff: https://our.intern.facebook.com/intern/diff/20346780/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33255

Pulled By: driazati

Differential Revision: D20346780

fbshipit-source-id: c8534954ef4adb2e3c880401acbee30cd284f3db
This commit is contained in:
davidriazati
2020-03-10 18:29:01 -07:00
committed by Facebook Github Bot
parent 4167db11f7
commit 23b2fba79a
10 changed files with 178 additions and 21 deletions

View File

@@ -91,7 +91,7 @@ namespace caffe2 {
namespace serialize {
// Minimum and maximum zip archive format versions this reader can load.
// Version 3 adds serialized container type tags (see [type tag serialization]).
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x3L;
// Version written by this build; type tagging is not yet enabled by default
// (add_type_tags is false), so the produced version stays at 2.
constexpr uint64_t kProducedFileFormatVersion = 0x2L;
// Writer-specific constants

View File

@@ -57,5 +57,44 @@ void testSaveExtraFilesHook() {
}
}
// TODO: Re-enable when add_type_tags is true
void testTypeTags() {
// Disabled (see TODO above): intended to round-trip containers through
// torch::pickle_save / torch::pickle_load and check that the loaded values
// carry the same static types, once add_type_tags is enabled.
// auto list = c10::List<c10::List<int64_t>>();
// list.push_back(c10::List<int64_t>({1, 2, 3}));
// list.push_back(c10::List<int64_t>({4, 5, 6}));
//
// auto dict = c10::Dict<std::string, at::Tensor>();
// dict.insert("Hello", torch::ones({2, 2}));
//
// auto dict_list = c10::List<c10::Dict<std::string, at::Tensor>>();
// for (size_t i = 0; i < 5; i++) {
// auto another_dict = c10::Dict<std::string, at::Tensor>();
// another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2}));
// dict_list.push_back(another_dict);
// }
//
// auto tuple = std::tuple<int, std::string>(2, "hi");
//
// struct TestItem {
// IValue value;
// TypePtr expected_type;
// };
// std::vector<TestItem> items = {
// {list, ListType::create(ListType::create(IntType::get()))},
// {2, IntType::get()},
// {dict, DictType::create(StringType::get(), TensorType::get())},
// {dict_list, ListType::create(DictType::create(StringType::get(), TensorType::get()))},
// {tuple, TupleType::create({IntType::get(), StringType::get()})}
// };
//
// for (auto item : items) {
// auto bytes = torch::pickle_save(item.value);
// auto loaded = torch::pickle_load(bytes);
// ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type));
// ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type()));
// }
}
} // namespace jit
} // namespace torch

View File

@@ -66,6 +66,7 @@ namespace jit {
_(ProfiledTensorTypeHashing) \
_(ScriptObject) \
_(SaveExtraFilesHook) \
_(TypeTags) \
_(DCE) \
_(CustomFusionNestedBlocks) \
_(ClassDerive) \

View File

@@ -1,5 +1,6 @@
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/frontend/parser_constants.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <queue>
namespace torch {

View File

@@ -91,6 +91,7 @@ IValue readArchiveAndTensors(
obj_loader ? std::move(*obj_loader) : nullptr,
std::move(read_record),
device);
unpickler.set_version(stream_reader.version());
return unpickler.parse_ivalue();
}
@@ -158,6 +159,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
// type and may access the tags. Since setstate has a known input type, we
// can correctly restore the tags now by applying the input type of set_state
// to the state object being passed.
// TODO: Remove once [serialization type tags] is landed
restoreAccurateTypeTags(
input, set_state->getSchema().arguments().at(1).type());
(*set_state)({obj, input});

View File

@@ -13,6 +13,11 @@ namespace jit {
using ::c10::IValue;
// Per-thread flag controlling whether the Pickler wraps containers in type
// tags (see startTypeTag/endTypeTag). Off by default.
thread_local bool add_type_tags = false;

// Returns whether container type tagging is enabled on the current thread.
bool getTypeTags() { return add_type_tags; }
// Protocol 2 is the highest that can be decoded by Python 2
// See https://docs.python.org/3/library/pickle.html#data-stream-format
constexpr static uint8_t PROTOCOL_VERSION = 2;
@@ -480,25 +485,53 @@ void Pickler::pushTensorReference(const IValue& ivalue) {
push<PickleOpCode>(PickleOpCode::REDUCE);
}
void Pickler::pushEmptyDict() {
push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
// startTypeTag() and endTypeTag() must be called in a pair, with 1 argument
// pushed on the stack in between them. They will add the type of a container
// ivalue to the stack as a string so we can preserve type tags across
// serialization
// Begin a type-tag wrapper: emits the GLOBAL for
// torch.jit._pickle.restore_type_tag so the container pushed between this
// call and endTypeTag() can be paired with its type string. No-op when
// type tagging is disabled.
void Pickler::startTypeTag() {
  if (!getTypeTags()) {
    return;
  }
  pushGlobal("torch.jit._pickle", "restore_type_tag");
}
// See startTypeTag
// Close a type-tag wrapper opened by startTypeTag(): pushes the container's
// type as a string, pairs it with the container in a 2-tuple, and applies
// restore_type_tag via REDUCE. No-op when type tagging is disabled.
void Pickler::endTypeTag(const IValue& ivalue) {
  if (!getTypeTags()) {
    return;
  }
  // Only container values (dicts and lists) carry type tags
  TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList());
  TORCH_INTERNAL_ASSERT(ivalue.type());
  // Push the container's type string
  pushString(ivalue.type()->python_str());
  // Fold (container, type string) into a tuple
  push<PickleOpCode>(PickleOpCode::TUPLE2);
  // Invoke restore_type_tag on that tuple
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
// Pushes a dict onto the pickle stack, wrapped in a type tag (see
// startTypeTag/endTypeTag) so its key/value types can be restored on load.
// NOTE(review): the diff rendering showed the item-pushing loop twice (old
// and new hunks interleaved) and a tautological `size() >= 0` check; this
// is the coherent single-pass version — items are emitted exactly once and
// MARK/SETITEMS are skipped entirely for an empty dict.
void Pickler::pushDict(const IValue& ivalue) {
  startTypeTag();
  push<PickleOpCode>(PickleOpCode::EMPTY_DICT);

  auto dict_items = iterationOrder(ivalue.toGenericDict());
  if (!dict_items.empty()) {
    push<PickleOpCode>(PickleOpCode::MARK);
    // iterationOrder sorts the dict for deterministic keys
    for (const auto& pair : dict_items) {
      pushIValue(pair.first);
      pushIValue(pair.second);
    }
    push<PickleOpCode>(PickleOpCode::SETITEMS);
  }

  endTypeTag(ivalue);
}
size_t Pickler::pushNextBinPut() {
@@ -517,15 +550,17 @@ size_t Pickler::pushNextBinPut() {
// Pushes a list onto the pickle stack, wrapped in a type tag (see
// startTypeTag/endTypeTag) so its element type survives serialization.
void Pickler::pushGenericList(const IValue& ivalue) {
  startTypeTag();
  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);

  // Emit every element between MARK and APPENDS
  push<PickleOpCode>(PickleOpCode::MARK);
  for (const IValue& element : ivalue.toListRef()) {
    pushIValue(element);
  }
  push<PickleOpCode>(PickleOpCode::APPENDS);

  endTypeTag(ivalue);
}
void Pickler::pushTuple(const IValue& ivalue) {

View File

@@ -109,6 +109,9 @@ struct WriteableTensorData {
uint64_t size_;
};
void setTypeTags(bool state);
bool getTypeTags();
class Pickler {
TH_DISALLOW_COPY_AND_ASSIGN(Pickler);
@@ -144,6 +147,8 @@ class Pickler {
private:
void pushIValueImpl(const IValue& ivalue);
void startTypeTag();
void endTypeTag(const IValue& value);
void pushBool(bool value);
void pushDouble(double value);
void pushGenericList(const IValue& ivalue);

View File

@@ -5,6 +5,7 @@
#endif
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <string>
#include "unpickler.h"
@@ -146,13 +147,29 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
}
}
// Overwrites the element/key/value types stored on a container IValue with
// the types from the given tag. Only lists and dicts carry type tags; any
// other type is an error.
void restoreContainerTypeTags(IValue& ivalue, TypePtr type) {
  if (auto list_type = type->cast<ListType>()) {
    ivalue.toList().unsafeSetElementType(list_type->getElementType());
  } else if (auto dict_type = type->cast<DictType>()) {
    auto dict = ivalue.toGenericDict();
    dict.unsafeSetKeyType(dict_type->getKeyType());
    dict.unsafeSetValueType(dict_type->getValueType());
  } else {
    AT_ERROR("Unknown type for tag restoration: " + type->python_str());
  }
}
// Consume the pickle stream and return the resulting IValue. Archives
// written before type tags were serialized (version <= 2) get their
// container tags reconstructed by inspecting the values; newer archives
// carry explicit tags in the stream, so no reconstruction is needed.
// NOTE(review): the diff rendering showed restoreAccurateTypeTagsIfPossible
// called both unconditionally and inside the version gate; the unconditional
// call would defeat the `version_ <= 2` gate, so only the gated call remains.
IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  if (version_ <= 2) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}
@@ -462,6 +479,24 @@ void Unpickler::readGlobal(
" has no tensor table\n");
stack_.emplace_back(tensor_table_->at(data.toInt()));
});
} else if (class_name == "restore_type_tag") {
globals_.emplace_back([this] {
auto data = stack_.back().toTuple()->elements();
auto type_str = data.at(1).toStringRef();
stack_.pop_back();
TypePtr type = nullptr;
auto entry = type_cache_.find(type_str);
if (entry != type_cache_.end()) {
type = entry->second;
} else {
type = c10::parseType(type_str);
type_cache_[type_str] = type;
}
// TODO: Use lookahead to avoid creating the tuple and immediately
// destroying it here
restoreContainerTypeTags(data.at(0), type);
stack_.emplace_back(data.at(0));
});
} else {
TypePtr elem_type = nullptr;
if (class_name == "build_intlist") {

View File

@@ -1,6 +1,8 @@
#pragma once
#include "pickler.h"
#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>
namespace torch {
namespace jit {
@@ -30,7 +32,8 @@ class Unpickler {
const std::vector<at::Tensor>* tensor_table)
: reader_(reader),
tensor_table_(tensor_table),
type_resolver_(std::move(type_resolver)) {}
type_resolver_(std::move(type_resolver)),
version_(caffe2::serialize::kProducedFileFormatVersion) {}
// Tensors inside the pickle contain meta-data; the raw tensor
// data is retrieved by calling `read_record`.
@@ -45,7 +48,8 @@ class Unpickler {
type_resolver_(std::move(type_resolver)),
obj_loader_(std::move(obj_loader)),
read_record_(std::move(read_record)),
device_(std::move(device)) {}
device_(std::move(device)),
version_(caffe2::serialize::kProducedFileFormatVersion) {}
// consume the pickle stream, producing an IValue from the contents.
// Type Tags: the pickler will restore the type tags on
@@ -55,6 +59,18 @@ class Unpickler {
// restoreAccurateTypeTags
IValue parse_ivalue();
// [type tag serialization]
// This is used to determine whether to restore type tags by recursively
// descending into the returned stack object (if version_number <= 2), or
// if version_number >= 3, to use the type strings included in the pickle
// archive for container types. By default this is set to
// `kProducedFileFormatVersion` so unless you're loading a pickle file
// from alongside a corresponding `version` file, you don't need to set
// the version manually.
void set_version(uint64_t version_number) {
version_ = version_number;
}
private:
// No arguments ensures that a template argument must be specified
// so that the number of bytes read / type read is explicit
@@ -109,13 +125,23 @@ class Unpickler {
std::vector<size_t> marks_;
const std::vector<at::Tensor>* tensor_table_;
// When deserializing types on lists and dicts, cache the type here
// so we don't have to parse the same type multiple times. Strings
// are already de-duplicated and replaced with BINGETs in the
// pickler, so we can just use the actual data pointer of each string.
std::unordered_map<std::string, c10::TypePtr> type_cache_;
// optionally nullptr, needs to be present for creating classes
TypeResolver type_resolver_;
ObjLoader obj_loader_;
IValue empty_tuple_;
std::function<at::DataPtr(const std::string&)> read_record_;
c10::optional<at::Device> device_;
// See [type tag serialization]
uint64_t version_;
};
void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag);

View File

@@ -1,6 +1,12 @@
# These functions are referenced from the pickle archives produced by
# ScriptModule.save()
# These (`build_*`) functions used to be used by `pickler.cpp` to specify
# the type of the list for certain special types, but now all lists get
# a type attached and restored via `restore_type_tag` below. The legacy
# functions should stick around for backwards-compatibility.
def build_intlist(data):
    """Legacy unpickling hook kept for backwards-compatibility.

    Older archives called this to mark a list of ints; in Python the data
    needs no conversion, so it is returned unchanged.
    """
    return data
@@ -21,3 +27,10 @@ def build_tensor_from_id(data):
if isinstance(data, int):
# just the id, can't really do anything
return data
def restore_type_tag(value, type_str):
    """Unpickling hook that reattaches a static type to a container.

    The JIT unpickler uses ``type_str`` to restore the full static type of
    lists/dicts when an archive is re-loaded; plain Python has no use for
    the tag, so the value is returned as-is.
    """
    return value