The Nested Pool (#168382)

This PR fixes issue #161193 by simply reversing the iteration order over captures_underway.
After discussing with @galv, we decided to land this minimal fix first to unblock nested MemPool usage.

Long-term, the underlying infrastructure (e.g., captures_underway) still needs refactoring to clearly define the interaction between graph capture, MemPools, and threads. That broader cleanup will be addressed in #168137.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/168382
Approved by: https://github.com/eqy, https://github.com/ngimel, https://github.com/galv
This commit is contained in:
Frank Lin
2025-12-03 18:37:58 +00:00
committed by PyTorch MergeBot
parent 87329491c8
commit d038b0130e
2 changed files with 44 additions and 9 deletions

View File

@@ -1838,9 +1838,11 @@ class DeviceCachingAllocator {
if (graph_reuse_context.find(info.capture_id) ==
graph_reuse_context.end()) {
bool found = false;
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto graph_pool = graph_pools.find(entry.first);
// Use the reverse iterator to search captures_underway in LIFO order.
for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
++it) {
if (it->second(stream)) {
auto graph_pool = graph_pools.find(it->first);
TORCH_INTERNAL_ASSERT(
graph_pool != graph_pools.end(),
"Could not find graph pool for capture.");
@@ -2530,10 +2532,10 @@ class DeviceCachingAllocator {
std::function<bool(cudaStream_t)> filter) {
std::lock_guard<std::recursive_mutex> lock(mutex);
create_or_incref_pool(mempool_id);
for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
++it2) {
for (auto it = captures_underway.begin(); it != captures_underway.end();
++it) {
TORCH_CHECK(
it2->first != mempool_id,
it->first != mempool_id,
"beginAllocateToPool: already recording to mempool_id");
}
captures_underway.emplace_back(mempool_id, std::move(filter));
@@ -2962,9 +2964,11 @@ class DeviceCachingAllocator {
// a capture, so it's usually 0, and we can short-circuit
// cudaStreamCaptureStatus (which does a TLS lookup).
if (C10_UNLIKELY(!captures_underway.empty())) {
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto it1 = graph_pools.find(entry.first);
// Use the reverse iterator to search captures_underway in LIFO order.
for (auto it = captures_underway.rbegin(); it != captures_underway.rend();
++it) {
if (it->second(stream)) {
auto it1 = graph_pools.find(it->first);
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;