torch.Package extend PyTorchStreamWriter to track written records (#52218)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52218

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D26429794

Pulled By: Lilyjjo

fbshipit-source-id: 5f68e7991c673ada629d0370c705520243d0637a
This commit is contained in:
Lillian Johnson
2021-02-22 15:00:34 -08:00
committed by Facebook GitHub Bot
parent a39b1c42c1
commit b72a72a477
5 changed files with 20 additions and 1 deletions

View File

@@ -222,6 +222,10 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
return out;
}
const std::vector<std::string>& PyTorchStreamWriter::getAllWrittenRecords() {
return files_written;
}
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
std::string ss = archive_name_plus_slash_ + name;
size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
@@ -360,6 +364,7 @@ void PyTorchStreamWriter::writeRecord(
nullptr,
0);
valid("writing file ", name.c_str());
files_written.push_back(name);
}
void PyTorchStreamWriter::writeEndOfFile() {

View File

@@ -140,6 +140,8 @@ class TORCH_API PyTorchStreamWriter final {
bool compress = false);
void writeEndOfFile();
const std::vector<std::string>& getAllWrittenRecords();
bool finalized() const {
return finalized_;
}
@@ -154,6 +156,7 @@ class TORCH_API PyTorchStreamWriter final {
void setup(const std::string& file_name);
void valid(const char* what, const char* info = "");
size_t current_pos_ = 0;
std::vector<std::string> files_written;
std::unique_ptr<mz_zip_archive> ar_;
std::string archive_name_;
std::string archive_name_plus_slash_;

View File

@@ -31,6 +31,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
data2[i] = data2.size() - i;
}
writer.writeRecord("key2", data2.data(), data2.size());
const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
ASSERT_EQ(written_records[0], "key1");
ASSERT_EQ(written_records[1], "key2");
writer.writeEndOfFile();
std::string the_file = oss.str();

View File

@@ -601,6 +601,8 @@ class PyTorchFileWriter(object):
def write_record(self, name: str, data: bytes, size: _int) -> None: ...
def write_end_of_file(self) -> None: ...
def set_min_version(self, version: _int) -> None: ...
def get_all_written_records(self) -> List[str]: ...
def archive_name(self) -> str: ...
...
def _jit_get_inline_everything_mode() -> _bool: ...

View File

@@ -877,7 +877,11 @@ void initJITBindings(PyObject* module) {
size_t size) {
return self.writeRecord(
name, reinterpret_cast<const char*>(data), size);
});
})
.def("archive_name", &PyTorchStreamWriter::archiveName)
.def(
"get_all_written_records",
&PyTorchStreamWriter::getAllWrittenRecords);
py::enum_<MobileOptimizerType>(m, "MobileOptimizerType")
.value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION)