mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
torch.Package extend PyTorchStreamWriter to track written records (#52218)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52218 Test Plan: Imported from OSS Reviewed By: suo Differential Revision: D26429794 Pulled By: Lilyjjo fbshipit-source-id: 5f68e7991c673ada629d0370c705520243d0637a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a39b1c42c1
commit
b72a72a477
@@ -222,6 +222,10 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
|
||||
return out;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& PyTorchStreamWriter::getAllWrittenRecords() {
|
||||
return files_written;
|
||||
}
|
||||
|
||||
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
|
||||
std::string ss = archive_name_plus_slash_ + name;
|
||||
size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
|
||||
@@ -360,6 +364,7 @@ void PyTorchStreamWriter::writeRecord(
|
||||
nullptr,
|
||||
0);
|
||||
valid("writing file ", name.c_str());
|
||||
files_written.push_back(name);
|
||||
}
|
||||
|
||||
void PyTorchStreamWriter::writeEndOfFile() {
|
||||
|
||||
@@ -140,6 +140,8 @@ class TORCH_API PyTorchStreamWriter final {
|
||||
bool compress = false);
|
||||
void writeEndOfFile();
|
||||
|
||||
const std::vector<std::string>& getAllWrittenRecords();
|
||||
|
||||
bool finalized() const {
|
||||
return finalized_;
|
||||
}
|
||||
@@ -154,6 +156,7 @@ class TORCH_API PyTorchStreamWriter final {
|
||||
void setup(const std::string& file_name);
|
||||
void valid(const char* what, const char* info = "");
|
||||
size_t current_pos_ = 0;
|
||||
std::vector<std::string> files_written;
|
||||
std::unique_ptr<mz_zip_archive> ar_;
|
||||
std::string archive_name_;
|
||||
std::string archive_name_plus_slash_;
|
||||
|
||||
@@ -31,6 +31,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
data2[i] = data2.size() - i;
|
||||
}
|
||||
writer.writeRecord("key2", data2.data(), data2.size());
|
||||
|
||||
const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
|
||||
ASSERT_EQ(written_records[0], "key1");
|
||||
ASSERT_EQ(written_records[1], "key2");
|
||||
|
||||
writer.writeEndOfFile();
|
||||
|
||||
std::string the_file = oss.str();
|
||||
|
||||
@@ -601,6 +601,8 @@ class PyTorchFileWriter(object):
|
||||
def write_record(self, name: str, data: bytes, size: _int) -> None: ...
|
||||
def write_end_of_file(self) -> None: ...
|
||||
def set_min_version(self, version: _int) -> None: ...
|
||||
def get_all_written_records(self) -> List[str]: ...
|
||||
def archive_name(self) -> str: ...
|
||||
...
|
||||
|
||||
def _jit_get_inline_everything_mode() -> _bool: ...
|
||||
|
||||
@@ -877,7 +877,11 @@ void initJITBindings(PyObject* module) {
|
||||
size_t size) {
|
||||
return self.writeRecord(
|
||||
name, reinterpret_cast<const char*>(data), size);
|
||||
});
|
||||
})
|
||||
.def("archive_name", &PyTorchStreamWriter::archiveName)
|
||||
.def(
|
||||
"get_all_written_records",
|
||||
&PyTorchStreamWriter::getAllWrittenRecords);
|
||||
|
||||
py::enum_<MobileOptimizerType>(m, "MobileOptimizerType")
|
||||
.value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION)
|
||||
|
||||
Reference in New Issue
Block a user