[jit] Add type tags to lists/dicts in pickle (#33255)
Summary:
Stacked PRs
 * #33474 - [jit] Remove list specializations from pickler
 * **#33255 - [jit] Add type tags to lists/dicts in pickle** (this PR)

This adds a global call to `torch.jit._pickle.restore_type_tag` for lists and dicts so that we can preserve their types after serialization. (https://our.intern.facebook.com/intern/diff/20346780/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/33255
Pulled By: driazati
Differential Revision: D20346780
fbshipit-source-id: c8534954ef4adb2e3c880401acbee30cd284f3db
Commit: 23b2fba79a
Parent: 4167db11f7
Committed by: Facebook Github Bot
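
For orientation, here is a minimal sketch of the round trip this change targets. It is not part of the patch; it mirrors the commented-out `testTypeTags` added below, and the tag is only emitted once the new `add_type_tags` flag (introduced in `pickler.cpp`, off by default in this commit) is enabled:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto dict = c10::Dict<std::string, at::Tensor>();
  dict.insert("Hello", torch::ones({2, 2}));

  // Round trip through the JIT pickler.
  auto bytes = torch::pickle_save(dict);
  auto loaded = torch::pickle_load(bytes);

  // With type tags emitted, this prints Dict[str, Tensor] instead of an
  // untyped dict; the commented-out test below asserts exactly that.
  std::cout << loaded.type()->python_str() << "\n";
  return 0;
}
```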
@@ -91,7 +91,7 @@ namespace caffe2 {
namespace serialize {

constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x2L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x3L;
constexpr uint64_t kProducedFileFormatVersion = 0x2L;

// Writer-specific constants
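
As a stand-alone illustration (not from the patch) of what the min/max window above means for a reader: an archive whose recorded version falls outside the supported range is rejected, and version 3 is the format in which container type tags are carried inside the pickle data.

```cpp
#include <cstdint>
#include <stdexcept>

constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x3L;

// Hypothetical helper: reject archives this build does not know how to read.
void checkFileFormatVersion(uint64_t version) {
  if (version < kMinSupportedFileFormatVersion ||
      version > kMaxSupportedFileFormatVersion) {
    throw std::runtime_error("Unsupported file format version");
  }
}
```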
@@ -57,5 +57,44 @@ void testSaveExtraFilesHook() {
  }
}

// TODO: Re-enable when add_type_tags is true
void testTypeTags() {
  // auto list = c10::List<c10::List<int64_t>>();
  // list.push_back(c10::List<int64_t>({1, 2, 3}));
  // list.push_back(c10::List<int64_t>({4, 5, 6}));
  //
  // auto dict = c10::Dict<std::string, at::Tensor>();
  // dict.insert("Hello", torch::ones({2, 2}));
  //
  // auto dict_list = c10::List<c10::Dict<std::string, at::Tensor>>();
  // for (size_t i = 0; i < 5; i++) {
  //   auto another_dict = c10::Dict<std::string, at::Tensor>();
  //   another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2}));
  //   dict_list.push_back(another_dict);
  // }
  //
  // auto tuple = std::tuple<int, std::string>(2, "hi");
  //
  // struct TestItem {
  //   IValue value;
  //   TypePtr expected_type;
  // };
  // std::vector<TestItem> items = {
  //     {list, ListType::create(ListType::create(IntType::get()))},
  //     {2, IntType::get()},
  //     {dict, DictType::create(StringType::get(), TensorType::get())},
  //     {dict_list, ListType::create(DictType::create(StringType::get(), TensorType::get()))},
  //     {tuple, TupleType::create({IntType::get(), StringType::get()})}
  // };
  //
  // for (auto item : items) {
  //   auto bytes = torch::pickle_save(item.value);
  //   auto loaded = torch::pickle_load(bytes);
  //   ASSERT_TRUE(loaded.type()->isSubtypeOf(item.expected_type));
  //   ASSERT_TRUE(item.expected_type->isSubtypeOf(loaded.type()));
  // }
}

} // namespace jit
} // namespace torch
@@ -66,6 +66,7 @@ namespace jit {
  _(ProfiledTensorTypeHashing) \
  _(ScriptObject) \
  _(SaveExtraFilesHook) \
  _(TypeTags) \
  _(DCE) \
  _(CustomFusionNestedBlocks) \
  _(ClassDerive) \
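
For readers unfamiliar with this registration style, here is a self-contained sketch of how an X-macro list like the one above works (the macro and test names are illustrative, not the real ones from `test/cpp/jit`): each `_(Name)` entry expands into both a test definition and a call site, so adding `_(TypeTags)` to the list is all that is needed to register the new test.

```cpp
#include <cstdio>

// Hypothetical list, mirroring the _(Name) entries above.
#define FORALL_EXAMPLE_TESTS(_) \
  _(SaveExtraFilesHook)         \
  _(TypeTags)

// Expand each entry into a test function definition...
#define DEFINE_TEST(name) \
  void test##name() { std::printf("running test%s\n", #name); }
FORALL_EXAMPLE_TESTS(DEFINE_TEST)
#undef DEFINE_TEST

// ...and into a call in the runner.
int main() {
#define RUN_TEST(name) test##name();
  FORALL_EXAMPLE_TESTS(RUN_TEST)
#undef RUN_TEST
  return 0;
}
```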
@@ -1,5 +1,6 @@
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/frontend/parser_constants.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <queue>

namespace torch {
@@ -91,6 +91,7 @@ IValue readArchiveAndTensors(
      obj_loader ? std::move(*obj_loader) : nullptr,
      std::move(read_record),
      device);
  unpickler.set_version(stream_reader.version());
  return unpickler.parse_ivalue();
}
@@ -158,6 +159,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
    // type and may access the tags. Since setstate has a known input type, we
    // can correctly restore the tags now by applying the input type of set_state
    // to the state object being passed.
    // TODO: Remove once [serialization type tags] is landed
    restoreAccurateTypeTags(
        input, set_state->getSchema().arguments().at(1).type());
    (*set_state)({obj, input});
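
A hedged sketch of the pattern in the comment above (stand-alone; `fixStateTags` is an illustrative name, not a function in this patch, and the include path follows this diff's layout): when the expected type is known, here from the `__setstate__` schema, accurate container tags can be re-applied to a value after it has been unpickled.

```cpp
#include <torch/csrc/jit/serialization/unpickler.h>  // declares restoreAccurateTypeTags
#include <torch/torch.h>

// Illustrative helper: argument 0 of __setstate__ is `self`, argument 1 is the
// state object, so its declared type is exactly the tag the state should carry.
void fixStateTags(
    const c10::IValue& state,
    const c10::TypePtr& expected_state_type) {
  torch::jit::restoreAccurateTypeTags(state, expected_state_type);
}
```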
@@ -13,6 +13,11 @@ namespace jit {

using ::c10::IValue;

thread_local bool add_type_tags = false;
bool getTypeTags() {
  return add_type_tags;
}

// Protocol 2 is the highest that can be decoded by Python 2
// See https://docs.python.org/3/library/pickle.html#data-stream-format
constexpr static uint8_t PROTOCOL_VERSION = 2;
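
A sketch of how the new thread-local gate is meant to be used (not part of the patch; it assumes `setTypeTags`, declared in `pickler.h` below, simply flips `add_type_tags`, and that a caller toggles it around a save):

```cpp
#include <torch/csrc/jit/serialization/pickler.h>  // setTypeTags / getTypeTags, path per this diff
#include <torch/torch.h>
#include <cstdio>

int main() {
  auto dict = c10::Dict<std::string, at::Tensor>();
  dict.insert("Hello", torch::ones({2, 2}));

  // With the flag on, every list/dict is wrapped in a
  // torch.jit._pickle.restore_type_tag call carrying its type string.
  torch::jit::setTypeTags(true);
  auto tagged = torch::pickle_save(dict);

  // With the flag off (the default in this commit) containers are pickled untagged.
  torch::jit::setTypeTags(false);
  auto untagged = torch::pickle_save(dict);

  std::printf("tagged: %zu bytes, untagged: %zu bytes\n",
              tagged.size(), untagged.size());
  return 0;
}
```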
@@ -480,25 +485,53 @@ void Pickler::pushTensorReference(const IValue& ivalue) {
  push<PickleOpCode>(PickleOpCode::REDUCE);
}

void Pickler::pushEmptyDict() {
  push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
}

// startTypeTag() and endTypeTag() must be called in a pair, with 1 argument
// pushed on the stack in between them. They will add the type of a container
// ivalue to the stack as a string so we can preserve type tags across
// serialization
void Pickler::startTypeTag() {
  if (getTypeTags()) {
    pushGlobal("torch.jit._pickle", "restore_type_tag");
  }
}

// See startTypeTag
void Pickler::endTypeTag(const IValue& ivalue) {
  if (getTypeTags()) {
    TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList());

    // Push the dict type
    TORCH_INTERNAL_ASSERT(ivalue.type());
    pushString(ivalue.type()->python_str());

    // Pop the dict and type into a tuple
    push<PickleOpCode>(PickleOpCode::TUPLE2);

    // Call function via reduce
    push<PickleOpCode>(PickleOpCode::REDUCE);
  }
}

void Pickler::pushDict(const IValue& ivalue) {
  pushEmptyDict();
  auto dict_items = iterationOrder(ivalue.toGenericDict());
  if (dict_items.size() == 0) {
    return;
  }
  startTypeTag();

  push<PickleOpCode>(PickleOpCode::EMPTY_DICT);

  if (dict_items.size() >= 0) {
    push<PickleOpCode>(PickleOpCode::MARK);

    // Sort the dict for deterministic keys
    for (const auto& pair : dict_items) {
      pushIValue(pair.first);
      pushIValue(pair.second);
    }

    push<PickleOpCode>(PickleOpCode::SETITEMS);
  }

  push<PickleOpCode>(PickleOpCode::MARK);

  // Sort the dict for deterministic keys
  for (const auto& pair : dict_items) {
    pushIValue(pair.first);
    pushIValue(pair.second);
  }

  push<PickleOpCode>(PickleOpCode::SETITEMS);
  endTypeTag(ivalue);
}

size_t Pickler::pushNextBinPut() {
@@ -517,15 +550,17 @@ size_t Pickler::pushNextBinPut() {

void Pickler::pushGenericList(const IValue& ivalue) {
  auto list = ivalue.toListRef();
  startTypeTag();

  // Push the list items
  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);

  push<PickleOpCode>(PickleOpCode::MARK);

  for (const IValue& item : list) {
    pushIValue(item);
  }

  push<PickleOpCode>(PickleOpCode::APPENDS);

  endTypeTag(ivalue);
}

void Pickler::pushTuple(const IValue& ivalue) {
@@ -109,6 +109,9 @@ struct WriteableTensorData {
  uint64_t size_;
};

void setTypeTags(bool state);
bool getTypeTags();

class Pickler {
  TH_DISALLOW_COPY_AND_ASSIGN(Pickler);
@@ -144,6 +147,8 @@ class Pickler {

 private:
  void pushIValueImpl(const IValue& ivalue);
  void startTypeTag();
  void endTypeTag(const IValue& value);
  void pushBool(bool value);
  void pushDouble(double value);
  void pushGenericList(const IValue& ivalue);
@@ -5,6 +5,7 @@
#endif
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <string>
#include "unpickler.h"
@@ -146,13 +147,29 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
  }
}

void restoreContainerTypeTags(IValue& ivalue, TypePtr type) {
  if (auto dict_type = type->cast<DictType>()) {
    auto dict = ivalue.toGenericDict();
    dict.unsafeSetKeyType(dict_type->getKeyType());
    dict.unsafeSetValueType(dict_type->getValueType());
  } else if (auto list_type = type->cast<ListType>()) {
    ivalue.toList().unsafeSetElementType(list_type->getElementType());
  } else {
    AT_ERROR("Unknown type for tag restoration: " + type->python_str());
  }
}

IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  restoreAccurateTypeTagsIfPossible(stack_[0]);
  if (version_ <= 2) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}
@@ -462,6 +479,24 @@ void Unpickler::readGlobal(
          " has no tensor table\n");
      stack_.emplace_back(tensor_table_->at(data.toInt()));
    });
  } else if (class_name == "restore_type_tag") {
    globals_.emplace_back([this] {
      auto data = stack_.back().toTuple()->elements();
      auto type_str = data.at(1).toStringRef();
      stack_.pop_back();
      TypePtr type = nullptr;
      auto entry = type_cache_.find(type_str);
      if (entry != type_cache_.end()) {
        type = entry->second;
      } else {
        type = c10::parseType(type_str);
        type_cache_[type_str] = type;
      }
      // TODO: Use lookahead to avoid creating the tuple and immediately
      // destroying it here
      restoreContainerTypeTags(data.at(0), type);
      stack_.emplace_back(data.at(0));
    });
  } else {
    TypePtr elem_type = nullptr;
    if (class_name == "build_intlist") {
@@ -1,6 +1,8 @@
#pragma once

#include "pickler.h"
#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>

namespace torch {
namespace jit {
@@ -30,7 +32,8 @@ class Unpickler {
      const std::vector<at::Tensor>* tensor_table)
      : reader_(reader),
        tensor_table_(tensor_table),
        type_resolver_(std::move(type_resolver)) {}
        type_resolver_(std::move(type_resolver)),
        version_(caffe2::serialize::kProducedFileFormatVersion) {}

  // tensors inside the pickle contain meta-data, the raw tensor
  // data is retrieved by calling `read_record`.
@@ -45,7 +48,8 @@ class Unpickler {
        type_resolver_(std::move(type_resolver)),
        obj_loader_(std::move(obj_loader)),
        read_record_(std::move(read_record)),
        device_(std::move(device)) {}
        device_(std::move(device)),
        version_(caffe2::serialize::kProducedFileFormatVersion) {}

  // consume the pickle stream, producing an IValue from the contents.
  // Type Tags: the pickler will restore the type tags on
@@ -55,6 +59,18 @@ class Unpickler {
  // restoreAccurateTypeTags
  IValue parse_ivalue();

  // [type tag serialization]
  // This is used to determine whether to restore type tags by recursively
  // descending into the returned stack object (if version_number <= 2), or
  // if version_number >= 3, to use the type strings included in the pickle
  // archive for container types. By default this is set to
  // `kProducedFileFormatVersion` so unless you're loading a pickle file
  // from alongside a corresponding `version` file, you don't need to set
  // the version manually.
  void set_version(uint64_t version_number) {
    version_ = version_number;
  }

 private:
  // No arguments ensures that a template argument must be specified
  // so that the number of bytes read / type read is explicit
@@ -109,13 +125,23 @@ class Unpickler {
  std::vector<size_t> marks_;
  const std::vector<at::Tensor>* tensor_table_;

  // When deserializing types on lists and dicts, cache the type here
  // so we don't have to parse the same type multiple times. Strings
  // are already de-duplicated and replaced with BINGETs in the
  // pickler, so we can just use the actual data pointer of each string.
  std::unordered_map<std::string, c10::TypePtr> type_cache_;

  // optionally nullptr, needs to be present for creating classes
  TypeResolver type_resolver_;
  ObjLoader obj_loader_;
  IValue empty_tuple_;

  std::function<at::DataPtr(const std::string&)> read_record_;
  c10::optional<at::Device> device_;

  // See [type tag serialization]
  uint64_t version_;
};

void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag);
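
A stand-alone sketch of the cache-on-first-parse pattern described in the `type_cache_` comment above (illustrative helper, not code from the patch; `c10::parseType` is the mobile type parser this diff already calls in `readGlobal`):

```cpp
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/mobile/type_parser.h>  // c10::parseType
#include <string>
#include <unordered_map>

c10::TypePtr parseTypeCached(
    const std::string& type_str,
    std::unordered_map<std::string, c10::TypePtr>& cache) {
  auto it = cache.find(type_str);
  if (it != cache.end()) {
    return it->second;  // same tag string seen before, reuse the parsed type
  }
  auto type = c10::parseType(type_str);  // e.g. "Dict[str, Tensor]"
  cache[type_str] = type;
  return type;
}
```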
@@ -1,6 +1,12 @@
# These functions are referenced from the pickle archives produced by
# ScriptModule.save()

# These (`build_*`) functions used to be used by `pickler.cpp` to specify
# the type of the list for certain special types, but now all lists get
# a type attached and restored via `restore_type_tag` below. The legacy
# functions should stick around for backwards-compatibility.

def build_intlist(data):
    return data
@@ -21,3 +27,10 @@ def build_tensor_from_id(data):
    if isinstance(data, int):
        # just the id, can't really do anything
        return data


def restore_type_tag(value, type_str):
    # The type_ptr is used by the jit unpickler to restore the full static type
    # to container types like list when they are re-loaded, but this doesn't
    # matter for Python, so just return the plain value
    return value