From 23b2fba79a6d2baadbb528b58ce6adb0ea929976 Mon Sep 17 00:00:00 2001 From: davidriazati Date: Tue, 10 Mar 2020 18:29:01 -0700 Subject: [PATCH] [jit] Add type tags to lists/dicts in pickle (#33255) Summary: Stacked PRs * #33474 - [jit] Remove list specializations from pickler * **#33255 - [jit] Add type tags to lists/dicts in pickle** This adds a global call to `torch.jit._pickle.restore_type_tags` for lists and dicts so that we can preserve their types after serialization. ](https://our.intern.facebook.com/intern/diff/20346780/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/33255 Pulled By: driazati Differential Revision: D20346780 fbshipit-source-id: c8534954ef4adb2e3c880401acbee30cd284f3db --- caffe2/serialize/inline_container.h | 2 +- test/cpp/jit/test_save_load.cpp | 39 ++++++++++++ test/cpp/jit/tests.h | 1 + torch/csrc/jit/mobile/type_parser.cpp | 1 + torch/csrc/jit/serialization/import.cpp | 2 + torch/csrc/jit/serialization/pickler.cpp | 69 ++++++++++++++++------ torch/csrc/jit/serialization/pickler.h | 5 ++ torch/csrc/jit/serialization/unpickler.cpp | 37 +++++++++++- torch/csrc/jit/serialization/unpickler.h | 30 +++++++++- torch/jit/_pickle.py | 13 ++++ 10 files changed, 178 insertions(+), 21 deletions(-) diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index f1b158f9a5e..fbf7deb6b10 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -91,7 +91,7 @@ namespace caffe2 { namespace serialize { constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L; -constexpr uint64_t kMaxSupportedFileFormatVersion = 0x2L; +constexpr uint64_t kMaxSupportedFileFormatVersion = 0x3L; constexpr uint64_t kProducedFileFormatVersion = 0x2L; // Writer-specific constants diff --git a/test/cpp/jit/test_save_load.cpp b/test/cpp/jit/test_save_load.cpp index 1302d7e6aa3..598cd5b2282 100644 --- a/test/cpp/jit/test_save_load.cpp +++ b/test/cpp/jit/test_save_load.cpp @@ -57,5 +57,44 @@ void 
testSaveExtraFilesHook() { } } + +// TODO: Re-enable when add_type_tags is true +void testTypeTags() { +// auto list = c10::List>(); +// list.push_back(c10::List({1, 2, 3})); +// list.push_back(c10::List({4, 5, 6})); +// +// auto dict = c10::Dict(); +// dict.insert("Hello", torch::ones({2, 2})); +// +// auto dict_list = c10::List>(); +// for (size_t i = 0; i < 5; i++) { +// auto another_dict = c10::Dict(); +// another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2})); +// dict_list.push_back(another_dict); +// } +// +// auto tuple = std::tuple(2, "hi"); +// +// struct TestItem { +// IValue value; +// TypePtr expected_type; +// }; +// std::vector items = { +// {list, ListType::create(ListType::create(IntType::get()))}, +// {2, IntType::get()}, +// {dict, DictType::create(StringType::get(), TensorType::get())}, +// {dict_list, ListType::create(DictType::create(StringType::get(), TensorType::get()))}, +// {tuple, TupleType::create({IntType::get(), StringType::get()})} +// }; +// +// for (auto item : items) { +// auto bytes = torch::pickle_save(item.value); +// auto loaded = torch::pickle_load(bytes); +// ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type)); +// ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type())); +// } +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index a0ae57d2b34..08b7e628313 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -66,6 +66,7 @@ namespace jit { _(ProfiledTensorTypeHashing) \ _(ScriptObject) \ _(SaveExtraFilesHook) \ + _(TypeTags) \ _(DCE) \ _(CustomFusionNestedBlocks) \ _(ClassDerive) \ diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index 3fd698d92ca..694026d53c4 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -1,5 +1,6 @@ #include #include +#include #include namespace torch { diff --git a/torch/csrc/jit/serialization/import.cpp 
b/torch/csrc/jit/serialization/import.cpp index a843f0d06fa..54c145205ef 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -91,6 +91,7 @@ IValue readArchiveAndTensors( obj_loader ? std::move(*obj_loader) : nullptr, std::move(read_record), device); + unpickler.set_version(stream_reader.version()); return unpickler.parse_ivalue(); } @@ -158,6 +159,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) { // type and may access the tags. Since setstate has a known input type, we // can correctly restore the tags now by apply the input type of set_state // to the state object being passed. + // TODO: Remove once [serialization type tags] is landed restoreAccurateTypeTags( input, set_state->getSchema().arguments().at(1).type()); (*set_state)({obj, input}); diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 1e36210ddf3..e5a3150a5ba 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -13,6 +13,11 @@ namespace jit { using ::c10::IValue; +thread_local bool add_type_tags = false; +bool getTypeTags() { + return add_type_tags; +} + // Protocol 2 is the highest that can be decoded by Python 2 // See https://docs.python.org/3/library/pickle.html#data-stream-format constexpr static uint8_t PROTOCOL_VERSION = 2; @@ -480,25 +485,53 @@ void Pickler::pushTensorReference(const IValue& ivalue) { push(PickleOpCode::REDUCE); } -void Pickler::pushEmptyDict() { - push(PickleOpCode::EMPTY_DICT); +// startTypeTag() and endTypeTag() must be called in a pair, with 1 argument +// pushed on the stack in between them. 
They will add the type of a container +// ivalue to the stack as a string so we can preserve type tags across +// serialization +void Pickler::startTypeTag() { + if (getTypeTags()) { + pushGlobal("torch.jit._pickle", "restore_type_tag"); + } } + +// See startTypeTag +void Pickler::endTypeTag(const IValue& ivalue) { + if (getTypeTags()) { + TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList()); + + // Push the dict type + TORCH_INTERNAL_ASSERT(ivalue.type()); + pushString(ivalue.type()->python_str()); + + // Pop the dict and type into a tuple + push(PickleOpCode::TUPLE2); + + // Call function via reduce + push(PickleOpCode::REDUCE); + } +} + void Pickler::pushDict(const IValue& ivalue) { - pushEmptyDict(); auto dict_items = iterationOrder(ivalue.toGenericDict()); - if (dict_items.size() == 0) { - return; + + startTypeTag(); + + push(PickleOpCode::EMPTY_DICT); + + if (dict_items.size() >= 0) { + push(PickleOpCode::MARK); + + // Sort the dict for deterministic keys + for (const auto& pair : dict_items) { + pushIValue(pair.first); + pushIValue(pair.second); + } + + push(PickleOpCode::SETITEMS); } - push(PickleOpCode::MARK); - - // Sort the dict for deterministic keys - for (const auto& pair : dict_items) { - pushIValue(pair.first); - pushIValue(pair.second); - } - - push(PickleOpCode::SETITEMS); + endTypeTag(ivalue); } size_t Pickler::pushNextBinPut() { @@ -517,15 +550,17 @@ size_t Pickler::pushNextBinPut() { void Pickler::pushGenericList(const IValue& ivalue) { auto list = ivalue.toListRef(); + startTypeTag(); + + // Push the list items push(PickleOpCode::EMPTY_LIST); - push(PickleOpCode::MARK); - for (const IValue& item : list) { pushIValue(item); } - push(PickleOpCode::APPENDS); + + endTypeTag(ivalue); } void Pickler::pushTuple(const IValue& ivalue) { diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 6ed6c4937eb..9f845feafc1 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ 
b/torch/csrc/jit/serialization/pickler.h @@ -109,6 +109,9 @@ struct WriteableTensorData { uint64_t size_; }; +void setTypeTags(bool state); +bool getTypeTags(); + class Pickler { TH_DISALLOW_COPY_AND_ASSIGN(Pickler); @@ -144,6 +147,8 @@ class Pickler { private: void pushIValueImpl(const IValue& ivalue); + void startTypeTag(); + void endTypeTag(const IValue& value); void pushBool(bool value); void pushDouble(double value); void pushGenericList(const IValue& ivalue); diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 53bb4ebd6d9..dcddc2993e3 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -5,6 +5,7 @@ #endif #include #include +#include #include #include "unpickler.h" @@ -146,13 +147,29 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { } } +void restoreContainerTypeTags(IValue& ivalue, TypePtr type) { + if (auto dict_type = type->cast()) { + auto dict = ivalue.toGenericDict(); + dict.unsafeSetKeyType(dict_type->getKeyType()); + dict.unsafeSetValueType(dict_type->getValueType()); + } else if (auto list_type = type->cast()) { + ivalue.toList().unsafeSetElementType(list_type->getElementType()); + } else { + AT_ERROR("Unknown type for tag restoration: " + type->python_str()); + } +} + + IValue Unpickler::parse_ivalue() { run(); TORCH_CHECK( stack_.size() == 1, "Unpickler expected 1 element on the stack, but found ", stack_.size()); - restoreAccurateTypeTagsIfPossible(stack_[0]); + if (version_ <= 2) { + // See [type tag serialization] + restoreAccurateTypeTagsIfPossible(stack_[0]); + } return stack_[0]; } @@ -462,6 +479,24 @@ void Unpickler::readGlobal( " has no tensor table\n"); stack_.emplace_back(tensor_table_->at(data.toInt())); }); + } else if (class_name == "restore_type_tag") { + globals_.emplace_back([this] { + auto data = stack_.back().toTuple()->elements(); + auto type_str = data.at(1).toStringRef(); + 
stack_.pop_back(); + TypePtr type = nullptr; + auto entry = type_cache_.find(type_str); + if (entry != type_cache_.end()) { + type = entry->second; + } else { + type = c10::parseType(type_str); + type_cache_[type_str] = type; + } + // TODO: Use lookahead to avoid creating the tuple and immediately + // destroying it here + restoreContainerTypeTags(data.at(0), type); + stack_.emplace_back(data.at(0)); + }); } else { TypePtr elem_type = nullptr; if (class_name == "build_intlist") { diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index d41002ea2ec..afed8d6a2a5 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -1,6 +1,8 @@ #pragma once #include "pickler.h" +#include +#include namespace torch { namespace jit { @@ -30,7 +32,8 @@ class Unpickler { const std::vector* tensor_table) : reader_(reader), tensor_table_(tensor_table), - type_resolver_(std::move(type_resolver)) {} + type_resolver_(std::move(type_resolver)), + version_(caffe2::serialize::kProducedFileFormatVersion) {} // tensors inside the pickle contain meta-data, the raw tensor // dead is retrieved by calling `read_record`. @@ -45,7 +48,8 @@ class Unpickler { type_resolver_(std::move(type_resolver)), obj_loader_(std::move(obj_loader)), read_record_(std::move(read_record)), - device_(std::move(device)) {} + device_(std::move(device)), + version_(caffe2::serialize::kProducedFileFormatVersion) {} // consume the pickle stream, producing an IValue from the contents. // Type Tags: the pickler will restore the type tags on @@ -55,6 +59,18 @@ class Unpickler { // restoreAccurateTypeTags IValue parse_ivalue(); + // [type tag serialization] + // This is used to determine whether to restore type tags be recursively + // descending into the returned stack object (if version_number <= 2), or + // if version_number >= 3, to use the type strings included in the pickle + // archive for container types. 
By default this is set to + // `kProducedFileFormatVersion` so unless you're loading a pickle file + // from alongside a corresponding `version` file, you don't need to set + // the version manually. + void set_version(uint64_t version_number) { + version_ = version_number; + } + private: // No arguments ensures that a template argument must be specified // so that the number of bytes read / type read is explicit @@ -109,13 +125,23 @@ class Unpickler { std::vector marks_; const std::vector* tensor_table_; + // When deserializing types on lists and dicts, cache the type here + // so we don't have to parse the same type multiple times. Strings + // are already de-duplicated and replaced with BINGETs in the + // pickler, so we can just use the actual data pointer of each string. + std::unordered_map type_cache_; + // optionally nullptr, needs to be present for creating classes TypeResolver type_resolver_; ObjLoader obj_loader_; IValue empty_tuple_; + std::function read_record_; c10::optional device_; + + // See [type tag serialization] + uint64_t version_; }; void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag); diff --git a/torch/jit/_pickle.py b/torch/jit/_pickle.py index 0d8992cf25e..db2982e822d 100644 --- a/torch/jit/_pickle.py +++ b/torch/jit/_pickle.py @@ -1,6 +1,12 @@ # These functions are referenced from the pickle archives produced by # ScriptModule.save() + +# These (`build_*`) functions used to be used by `pickler.cpp` to specify +# the type of the list for certain special types, but now all lists get +# a type attached and restored via `restore_type_tag` below. The legacy +# functions should stick around for backwards-compatibility. 
+ def build_intlist(data): return data @@ -21,3 +27,10 @@ def build_tensor_from_id(data): if isinstance(data, int): # just the id, can't really do anything return data + + +def restore_type_tag(value, type_str): + # The type_str is used by the jit unpickler to restore the full static type + # to container types like list when they are re-loaded, but this doesn't + # matter for Python, so just return the plain value + return value