diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc
index 69a8432d40e..fbf96a1db23 100644
--- a/caffe2/serialize/inline_container.cc
+++ b/caffe2/serialize/inline_container.cc
@@ -8,6 +8,7 @@
 #include <sstream>
 #include <sys/stat.h>
 #include <sys/types.h>
+#include <thread>
 
 #include <c10/core/Allocator.h>
 #include <c10/core/Backend.h>
@@ -346,6 +347,88 @@ std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string
   return std::make_tuple(std::move(retval), stat.m_uncomp_size);
 }
 
+size_t
+PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
+    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
+    void* dst, size_t n) {
+  size_t nthread = additionalReaders.size() + 1;
+  size_t recordOff = getRecordOffset(name);
+  std::vector<std::thread> loaderThreads;
+  size_t perThreadSize = (n + nthread - 1) / nthread;
+  std::vector<size_t> readSizes(nthread, 0);
+  std::lock_guard<std::mutex> guard(reader_lock_);
+  for (size_t i = 0; i < nthread; i++) {
+    loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes] {
+      size_t startPos = i * perThreadSize;
+      size_t endPos = std::min((i + 1) * perThreadSize, n);
+      if (startPos < endPos) {
+        size_t threadReadSize = endPos - startPos;
+        size_t size = 0;
+        if (i == 0) {
+          // The default reader handles the first chunk.
+          size = read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
+        } else {
+          auto reader = additionalReaders[i - 1];
+          size = reader->read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
+        }
+        readSizes[i] = size;
+        LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
+                  << "from " << name << " of size " << n;
+        TORCH_CHECK(
+            threadReadSize == size,
+            "record size ",
+            threadReadSize,
+            " mismatch with read size ",
+            size);
+      }
+    });
+  }
+
+  for (auto& thread : loaderThreads) {
+    thread.join();
+  }
+  loaderThreads.clear();
+
+  size_t total_read_n = 0;
+  for (auto& r : readSizes) {
+    total_read_n += r;
+  }
+
+  TORCH_CHECK(
+      n == total_read_n,
+      "Multi reader total read size ",
+      total_read_n,
+      " mismatch with dst size ",
+      n);
+
+  return total_read_n;
+}
+
+// read a record with multiple clients
+std::tuple<at::DataPtr, size_t>
+PyTorchStreamReader::getRecord(const std::string& name,
+    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
+  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
+    at::DataPtr retval;
+    return std::make_tuple(std::move(retval), 0);
+  }
+  size_t key = getRecordID(name);
+  mz_zip_archive_file_stat stat;
+  mz_zip_reader_file_stat(ar_.get(), key, &stat);
+  auto n = stat.m_uncomp_size;
+  valid("retrieving file meta-data for ", name.c_str());
+  if (additionalReaders.empty() || n < additional_reader_size_threshold_) {
+    // No additional readers, or the record is below the threshold;
+    // use the single-threaded version.
+    return getRecord(name);
+  }
+
+  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
+  void* dst = retval.get();
+  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
+  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
+}
+
 // inplace memory writing
 size_t
 PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
@@ -369,6 +452,34 @@ PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
   return stat.m_uncomp_size;
 }
 
+
+// inplace memory writing, multi-threaded within a single record; useful for large tensors.
+size_t
+PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
+    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
+  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
+    return 0;
+  }
+  size_t key = getRecordID(name);
+  mz_zip_archive_file_stat stat;
+  mz_zip_reader_file_stat(ar_.get(), key, &stat);
+  TORCH_CHECK(
+      n == stat.m_uncomp_size,
+      "record size ",
+      stat.m_uncomp_size,
+      " mismatch with dst size ",
+      n);
+  valid("retrieving file meta-data for ", name.c_str());
+
+  if (additionalReaders.empty() || n < additional_reader_size_threshold_) {
+    // No additional readers, or the record is below the threshold;
+    // use the single-threaded version.
+    return getRecord(name, dst, n);
+  }
+
+  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
+  return stat.m_uncomp_size;
+}
+
 size_t
 PyTorchStreamReader::getRecord(
     const std::string& name,
     void* dst,
diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h
index e7981e74bcb..aa0cb8e0432 100644
--- a/caffe2/serialize/inline_container.h
+++ b/caffe2/serialize/inline_container.h
@@ -16,6 +16,7 @@
 #include "caffe2/serialize/read_adapter_interface.h"
 #include "caffe2/serialize/versions.h"
 
+
 extern "C" {
 typedef struct mz_zip_archive mz_zip_archive;
 }
@@ -126,8 +127,15 @@ class TORCH_API PyTorchStreamReader final {
   // return dataptr, size
   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
+  // multi-threaded getRecord
+  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
   // inplace memory writing
   size_t getRecord(const std::string& name, void* dst, size_t n);
+  // inplace memory writing, multi-threaded.
+  // When additionalReaders is empty, the default behavior is to call getRecord(name, dst, n) with the default reader.
+  // This approach can be used for reading large tensors.
+  size_t getRecord(const std::string& name, void* dst, size_t n,
+      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
   size_t getRecord(
       const std::string& name,
       void* dst,
@@ -136,6 +144,20 @@ class TORCH_API PyTorchStreamReader final {
       void* buf,
       const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
 
+  // Concurrently read a record with multiple readers.
+  // additionalReaders are additional clients that access the underlying record
+  // at different offsets and write into different chunks of the destination buffer.
+  // For example, if the overall record size is 10 and additionalReaders holds two
+  // readers, the work is split into chunks of ceil(10 / 3) = 4 bytes:
+  // the default reader reads [0, 4), the first additional reader reads [4, 8),
+  // and the second additional reader reads [8, 10).
+  // Each reader writes to the matching chunk of the buffer:
+  // buffer[0, 4), buffer[4, 8), and buffer[8, 10).
+  // When additionalReaders is empty, the default behavior is to call getRecord(name) with the default reader.
+  // This approach can be used for reading large tensors.
+  size_t getRecordMultiReaders(const std::string& name,
+      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
+      void* dst, size_t n);
+
   size_t getRecordSize(const std::string& name);
 
   size_t getRecordOffset(const std::string& name);
@@ -158,7 +180,9 @@ class TORCH_API PyTorchStreamReader final {
   void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
     load_debug_symbol_ = should_load_debug_symbol;
   }
-
+  void setAdditionalReaderSizeThreshold(const size_t& size) {
+    additional_reader_size_threshold_ = size;
+  }
  private:
   void init();
   size_t read(uint64_t pos, char* buf, size_t n);
@@ -175,6 +199,7 @@ class TORCH_API PyTorchStreamReader final {
   std::mutex reader_lock_;
   bool load_debug_symbol_ = true;
   std::string serialization_id_;
+  size_t additional_reader_size_threshold_ = 0;
 };
 
 class TORCH_API PyTorchStreamWriter final {
diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc
index 821de91e938..b2313d39e6a 100644
--- a/caffe2/serialize/inline_container_test.cc
+++ b/caffe2/serialize/inline_container_test.cc
@@ -105,6 +105,86 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   remove(file_name);
 }
 
+TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
+  std::ostringstream oss;
+  // write records through writers
+  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
+    oss.write(static_cast<const char*>(b), n);
+    return oss ? n : 0;
+  });
+
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
+  std::array<char, 127> data1;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
+  std::array<char, 64> data2;
+  for (auto i : c10::irange(data1.size())) {
+    data1[i] = data1.size() - i;
+  }
+  writer.writeRecord("key1", data1.data(), data1.size());
+
+  for (auto i : c10::irange(data2.size())) {
+    data2[i] = data2.size() - i;
+  }
+  writer.writeRecord("key2", data2.data(), data2.size());
+
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
+
+  writer.writeEndOfFile();
+  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
+
+  std::string the_file = oss.str();
+  const char* file_name = "output.zip";
+  std::ofstream foo(file_name);
+  foo.write(the_file.c_str(), the_file.size());
+  foo.close();
+
+  // read records through PyTorchStreamReader
+  std::istringstream iss(the_file);
+  PyTorchStreamReader reader(&iss);
+  reader.setAdditionalReaderSizeThreshold(0);
+  // sanity check before testing
+  int64_t size1, size2, ret;
+  at::DataPtr data_ptr;
+  std::tie(data_ptr, size1) = reader.getRecord("key1");
+  std::tie(data_ptr, size2) = reader.getRecord("key2");
+
+  // Test getRecord(name, additional_readers)
+  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
+  for (int i = 0; i < 10; ++i) {
+    // Test various numbers of additional readers.
+    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
+    std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
+    ASSERT_EQ(ret, size1);
+    ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);
+
+    std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
+    ASSERT_EQ(ret, size2);
+    ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
+  }
+
+  // In-place multi-threaded getRecord(name, dst, n, additional_readers) test
+  additionalReader.clear();
+  std::vector<uint8_t> dst1(size1), dst2(size2);
+  for (int i = 0; i < 10; ++i) {
+    // Test various numbers of read threads.
+    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
+
+    ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
+    ASSERT_EQ(ret, size1);
+    ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);
+
+    ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
+    ASSERT_EQ(ret, size2);
+    ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
+  }
+  // clean up
+  remove(file_name);
+}
+
 TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
   std::ostringstream oss;
   // write records through writers
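
For context, here is a minimal usage sketch of the new multi-reader API. The archive path "model.pt", the record name "data/0", and the thread count are illustrative assumptions, not part of this patch; PyTorchStreamReader, IStreamAdapter, and ReadAdapterInterface are the caffe2/serialize types touched above.

#include <fstream>
#include <memory>
#include <vector>

#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"

using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

int main() {
  // Hypothetical archive path; any PyTorch zip archive works.
  std::ifstream main_stream("model.pt", std::ios::binary);
  PyTorchStreamReader reader(&main_stream);

  // Records smaller than this threshold fall back to the
  // single-threaded getRecord; larger ones are split across readers.
  reader.setAdditionalReaderSizeThreshold(100 * 1024 * 1024);

  // Each additional reader gets its own stream over the same file,
  // so the loader threads can seek and read independently.
  std::vector<std::unique_ptr<std::ifstream>> streams;
  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReaders;
  for (int i = 0; i < 3; ++i) {
    streams.push_back(std::make_unique<std::ifstream>("model.pt", std::ios::binary));
    additionalReaders.push_back(std::make_shared<IStreamAdapter>(streams.back().get()));
  }

  // Four threads (the default reader plus three additional readers)
  // each read one contiguous chunk of the record concurrently.
  auto [data, size] = reader.getRecord("data/0", additionalReaders);
  return 0;
}

Opening a separate stream per additional reader, as above, avoids the seek contention the test sidesteps by using small in-memory records.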