The Nested Pool (#168382)
This PR fixes issue #161193 by simply reversing the iteration order over captures_underway. After discussing with @galv, we decided to land this minimal fix first to unblock nested MemPool usage. Long-term, the underlying infrastructure (e.g., captures_underway) still needs refactoring to clearly define the interaction between graph capture, MemPools, and threads. That broader cleanup will be addressed in #168137.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168382
Approved by: https://github.com/eqy, https://github.com/ngimel, https://github.com/galv
committed by PyTorch MergeBot
parent 87329491c8
commit d038b0130e
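To see why iteration order matters, here is a minimal standalone sketch (illustrative names only, not the allocator's actual types): captures_underway behaves like a stack of (pool id, stream filter) entries, each nested use_mem_pool pushes onto the back, and with catch-all filters more than one entry can match the same stream. A forward scan then resolves to the outermost pool, while a reverse scan resolves to the innermost, most recently entered one:

// Standalone sketch: why reverse (LIFO) iteration picks the innermost pool.
// All types and names here are illustrative, not PyTorch's allocator code.
#include <cstdio>
#include <functional>
#include <initializer_list>
#include <utility>
#include <vector>

using Stream = int;
using PoolId = int;
using Entry = std::pair<PoolId, std::function<bool(Stream)>>;

int main() {
  std::vector<Entry> captures_underway;
  // Nested use_mem_pool: pool 1 entered first, then 2, then 3 (innermost).
  // Each filter accepts every stream, as a catch-all filter would.
  for (PoolId id : {1, 2, 3}) {
    captures_underway.emplace_back(id, [](Stream) { return true; });
  }

  Stream s = 0;
  // Forward scan: first match is the OUTERMOST pool (prints 1) -- the old bug.
  for (auto& entry : captures_underway) {
    if (entry.second(s)) {
      std::printf("forward scan picks pool %d\n", entry.first);
      break;
    }
  }
  // Reverse scan: first match is the INNERMOST pool (prints 3) -- the fix.
  for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
       ++it) {
    if (it->second(s)) {
      std::printf("reverse scan picks pool %d\n", it->first);
      break;
    }
  }
  return 0;
}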
@@ -1838,9 +1838,11 @@ class DeviceCachingAllocator {
       if (graph_reuse_context.find(info.capture_id) ==
           graph_reuse_context.end()) {
         bool found = false;
-        for (auto& entry : captures_underway) {
-          if (entry.second(stream)) {
-            auto graph_pool = graph_pools.find(entry.first);
+        // Use the reverse iterator to search captures_underway in LIFO order.
+        for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
+             ++it) {
+          if (it->second(stream)) {
+            auto graph_pool = graph_pools.find(it->first);
             TORCH_INTERNAL_ASSERT(
                 graph_pool != graph_pools.end(),
                 "Could not find graph pool for capture.");
@@ -2530,10 +2532,10 @@ class DeviceCachingAllocator {
       std::function<bool(cudaStream_t)> filter) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
     create_or_incref_pool(mempool_id);
-    for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
-         ++it2) {
+    for (auto it = captures_underway.begin(); it != captures_underway.end();
+         ++it) {
       TORCH_CHECK(
-          it2->first != mempool_id,
+          it->first != mempool_id,
           "beginAllocateToPool: already recording to mempool_id");
     }
     captures_underway.emplace_back(mempool_id, std::move(filter));
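The hunk above only renames the loop variable (it2 to it); the guard it iterates is unchanged. A hedged sketch of what that guard enforces, with assert standing in for TORCH_CHECK and illustrative names throughout: beginning allocation to a pool id already present in captures_underway must fail, which is why nesting only works across distinct pools:

// Standalone sketch of the re-entrancy guard in beginAllocateToPool.
// Illustrative names; assert stands in for TORCH_CHECK (build without -DNDEBUG).
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

using Stream = int;
using PoolId = int;
using Entry = std::pair<PoolId, std::function<bool(Stream)>>;

static std::vector<Entry> captures_underway;

void begin_allocate_to_pool(PoolId mempool_id, std::function<bool(Stream)> filter) {
  for (auto it = captures_underway.begin(); it != captures_underway.end(); ++it) {
    // Mirrors TORCH_CHECK(it->first != mempool_id, "...already recording...").
    assert(it->first != mempool_id && "beginAllocateToPool: already recording to mempool_id");
  }
  captures_underway.emplace_back(mempool_id, std::move(filter));
}

int main() {
  begin_allocate_to_pool(1, [](Stream) { return true; });
  begin_allocate_to_pool(2, [](Stream) { return true; });  // ok: distinct id
  // begin_allocate_to_pool(1, ...) here would trip the assert.
  return 0;
}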
@@ -2962,9 +2964,11 @@ class DeviceCachingAllocator {
     // a capture, so it's usually 0, and we can short-circuit
     // cudaStreamCaptureStatus (which does a TLS lookup).
     if (C10_UNLIKELY(!captures_underway.empty())) {
-      for (auto& entry : captures_underway) {
-        if (entry.second(stream)) {
-          auto it1 = graph_pools.find(entry.first);
+      // Use the reverse iterator to search captures_underway in LIFO order.
+      for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
+           ++it) {
+        if (it->second(stream)) {
+          auto it1 = graph_pools.find(it->first);
           TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
           if (size <= kSmallSize) {
             return it1->second->small_blocks;
@@ -5710,6 +5710,37 @@ class TestMemPool(TestCase):
         s = p.snapshot()
         self.assertEqual(len(s), 1, "Expected to have a single segment")
 
+    @serialTest()
+    def test_nested_mempool(self):
+        torch.cuda.empty_cache()
+        pool1 = torch.cuda.MemPool()
+        pool2 = torch.cuda.MemPool()
+        pool3 = torch.cuda.MemPool()
+
+        data = []
+        nelem_1mb = 1024 * 1024 // 4
+
+        def allocate_data():
+            x = torch.empty(nelem_1mb * 20, device="cuda")
+            data.append(x)
+
+        with torch.cuda.use_mem_pool(pool1):
+            allocate_data()
+            with torch.cuda.use_mem_pool(pool2):
+                allocate_data()
+                with torch.cuda.use_mem_pool(pool3):
+                    allocate_data()
+                allocate_data()
+            allocate_data()
+
+        pool1_segments = torch.cuda.memory.memory_snapshot(pool1.id)
+        pool2_segments = torch.cuda.memory.memory_snapshot(pool2.id)
+        pool3_segments = torch.cuda.memory.memory_snapshot(pool3.id)
+
+        self.assertEqual(len(pool1_segments), 2)
+        self.assertEqual(len(pool2_segments), 2)
+        self.assertEqual(len(pool3_segments), 1)
+
     @unittest.skipIf(
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )
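As for why the test expects the counts (2, 2, 1): each allocate_data() call requests nelem_1mb * 20 float32 elements, i.e. 20 MB, and the test assumes every such allocation lands as one fresh segment in whichever pool is innermost at that point. A standalone sketch (C++17, illustrative names) replaying the nesting with a stack and the LIFO rule reproduces the asserted counts:

// Standalone sketch: replaying the test's nesting with the LIFO rule
// predicts the asserted segment counts (pool1: 2, pool2: 2, pool3: 1).
#include <cstdio>
#include <map>
#include <vector>

int main() {
  std::vector<int> pool_stack;        // innermost active pool sits at the back
  std::map<int, int> segments_per_pool;

  auto enter = [&](int id) { pool_stack.push_back(id); };
  auto leave = [&]() { pool_stack.pop_back(); };
  // Assumption: each 20 MB allocation becomes one segment in the active pool.
  auto allocate_data = [&]() { ++segments_per_pool[pool_stack.back()]; };

  enter(1);                 // with torch.cuda.use_mem_pool(pool1):
  allocate_data();
  enter(2);                 //     with torch.cuda.use_mem_pool(pool2):
  allocate_data();
  enter(3);                 //         with torch.cuda.use_mem_pool(pool3):
  allocate_data();
  leave();                  //         back in pool2
  allocate_data();
  leave();                  //     back in pool1
  allocate_data();
  leave();

  for (const auto& [id, n] : segments_per_pool) {
    std::printf("pool%d: %d segment(s)\n", id, n);
  }
  return 0;
}

Running it prints pool1: 2, pool2: 2, pool3: 1, matching the three assertEqual calls in the test.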