The Nested Pool (#168382)

This PR fixes issue #161193 by reversing the iteration order over captures_underway, so that the most recently begun capture — i.e., the innermost nested MemPool — is matched first (LIFO order). After discussing with @galv, we decided to land this minimal fix first to unblock nested MemPool usage.

Long-term, the underlying infrastructure (e.g., captures_underway) still needs refactoring to clearly define the interaction between graph capture, MemPools, and threads. That broader cleanup will be addressed in #168137.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168382
Approved by: https://github.com/eqy, https://github.com/ngimel, https://github.com/galv
This commit is contained in:
Frank Lin
2025-12-03 18:37:58 +00:00
committed by PyTorch MergeBot
parent 87329491c8
commit d038b0130e
2 changed files with 44 additions and 9 deletions

View File

@@ -1838,9 +1838,11 @@ class DeviceCachingAllocator {
if (graph_reuse_context.find(info.capture_id) ==
graph_reuse_context.end()) {
bool found = false;
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto graph_pool = graph_pools.find(entry.first);
// Use the reverse iterator to search captures_underway in LIFO order.
for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
++it) {
if (it->second(stream)) {
auto graph_pool = graph_pools.find(it->first);
TORCH_INTERNAL_ASSERT(
graph_pool != graph_pools.end(),
"Could not find graph pool for capture.");
@@ -2530,10 +2532,10 @@ class DeviceCachingAllocator {
std::function<bool(cudaStream_t)> filter) {
std::lock_guard<std::recursive_mutex> lock(mutex);
create_or_incref_pool(mempool_id);
for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
++it2) {
for (auto it = captures_underway.begin(); it != captures_underway.end();
++it) {
TORCH_CHECK(
it2->first != mempool_id,
it->first != mempool_id,
"beginAllocateToPool: already recording to mempool_id");
}
captures_underway.emplace_back(mempool_id, std::move(filter));
@@ -2962,9 +2964,11 @@ class DeviceCachingAllocator {
// a capture, so it's usually 0, and we can short-circuit
// cudaStreamCaptureStatus (which does a TLS lookup).
if (C10_UNLIKELY(!captures_underway.empty())) {
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto it1 = graph_pools.find(entry.first);
// Use the reverse iterator to search captures_underway in LIFO order.
for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
++it) {
if (it->second(stream)) {
auto it1 = graph_pools.find(it->first);
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;

View File

@@ -5710,6 +5710,37 @@ class TestMemPool(TestCase):
s = p.snapshot()
self.assertEqual(len(s), 1, "Expected to have a single segment")
@serialTest()
def test_nested_mempool(self):
    # Regression test for nested MemPool contexts (pytorch/pytorch#161193):
    # allocations made while several pools are active must land in the
    # innermost (most recently entered) pool, not an outer one.
    torch.cuda.empty_cache()

    outer_pool = torch.cuda.MemPool()
    middle_pool = torch.cuda.MemPool()
    inner_pool = torch.cuda.MemPool()

    kept_alive = []
    floats_per_mb = 1024 * 1024 // 4  # fp32 elements per MiB

    def grab_20mb():
        # Keep a reference so the backing segment stays live in its pool.
        kept_alive.append(torch.empty(floats_per_mb * 20, device="cuda"))

    # Allocate once at each nesting depth on the way in, and once more at
    # the two outer depths on the way out, so the expected per-pool segment
    # counts are outer=2, middle=2, inner=1.
    with torch.cuda.use_mem_pool(outer_pool):
        grab_20mb()
        with torch.cuda.use_mem_pool(middle_pool):
            grab_20mb()
            with torch.cuda.use_mem_pool(inner_pool):
                grab_20mb()
            grab_20mb()
        grab_20mb()

    # Each pool should own exactly the segments allocated at its depth.
    for pool, expected_segments in (
        (outer_pool, 2),
        (middle_pool, 2),
        (inner_pool, 1),
    ):
        snapshot = torch.cuda.memory.memory_snapshot(pool.id)
        self.assertEqual(len(snapshot), expected_segments)
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)