Mempool use_on_oom order (#169699)

Reorder the OOM mitigation steps so that we reuse opt-in (use_on_oom=True) mempools before the expensive step of releasing cached blocks.
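
For illustration, a rough Python model of the allocator's short-circuit attempt chain after this change. The helper names mirror the C++ functions in DeviceCachingAllocator, but the callables here are placeholders and guards/bookkeeping are omitted; this is not the real implementation:

```python
def attempt_allocation(alloc_block, try_mempool_fallback,
                       release_available_cached_blocks, release_cached_blocks):
    # Rough model of the attempt order; each argument is a zero-argument
    # callable that returns True once the allocation has been satisfied.
    return (
        alloc_block()  # plain allocation attempt
        # NEW position: reuse blocks cached in use_on_oom mempools (cheap)
        or try_mempool_fallback()
        # free enough cached blocks to satisfy the request, then retry
        or (release_available_cached_blocks() and alloc_block())
        # last resort: free all cached blocks, then retry
        or (release_cached_blocks() and alloc_block())
    )
```

Previously the mempool fallback only ran as a last resort, after the cache-releasing steps had already failed.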

Additionally, make sure mempools are removed from use_on_oom_pools upon deletion.
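
At the Python level, the user-visible contract looks roughly like this (a minimal sketch; sizes are illustrative, and the fallback only actually kicks in when a normal allocation would otherwise fail):

```python
import torch

# Opt this pool in as an OOM fallback: when a later allocation would otherwise
# fail, the caching allocator may reuse free blocks cached inside this pool.
pool = torch.cuda.MemPool(use_on_oom=True)

with torch.cuda.use_mem_pool(pool):
    scratch = torch.randn(1 << 20, device="cuda")
del scratch  # the freed block stays cached in the pool

# Under memory pressure, this allocation may be served from the pool's cache.
x = torch.randn(1 << 20, device="cuda")
del x

# Deleting the pool must also drop it from use_on_oom_pools; before this fix
# the allocator could later walk the stale entry on OOM and hit the
# `use_count > 0` internal assert seen in the test output below.
del pool
```
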
New test before fix:
```
======================================================================
ERROR: test_deleted_mempool_not_used_on_oom (__main__.TestMemPool.test_deleted_mempool_not_used_on_oom)
Test that a deleted mempool with use_on_oom=True is properly removed from use_on_oom_pools.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/danielsjohnson/oss_pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3325, in wrapper
    method(*args, **kwargs)
  File "/home/danielsjohnson/oss_pytorch/pytorch/test/test_cuda.py", line 5696, in test_deleted_mempool_not_used_on_oom
    c = torch.randn(20 * nelem_1mb, device="cuda")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: it->second->use_count > 0 INTERNAL ASSERT FAILED at "/home/danielsjohnson/oss_pytorch/pytorch/c10/cuda/CUDACachingAllocator.cpp":2700, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
    python test/test_cuda.py TestMemPool.test_deleted_mempool_not_used_on_oom
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 0.691s
FAILED (errors=1)
Segmentation fault (core dumped)
```

New test after fix:
```
----------------------------------------------------------------------
Ran 1 test in 0.651s
OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169699
Approved by: https://github.com/ngimel, https://github.com/eqy
Author: Dan Johnson
Date: 2025-12-10 22:23:23 +00:00
Committed by: PyTorch MergeBot
Commit: f45018788b (parent: 832f73fe86)
6 changed files with 101 additions and 39 deletions


```diff
@@ -29,7 +29,7 @@ MemPool::MemPool(
   device_ = c10::cuda::current_device();
   CUDACachingAllocator::createOrIncrefPool(device_, id_, allocator);
   if (use_on_oom) {
-    CUDACachingAllocator::setUseOnOOM(device_, id_);
+    CUDACachingAllocator::setUseOnOOM(device_, id_, true);
   }
   if (no_split) {
     CUDACachingAllocator::setNoSplit(device_, id_);
@@ -42,6 +42,7 @@ MemPool::~MemPool() {
   // However, this assertion is not true if a memory pool is shared
   // with a cuda graph. That CUDAGraph will increase the use count
   // until it is reset.
+  CUDACachingAllocator::setUseOnOOM(device_, id_, false);
   CUDACachingAllocator::releasePool(device_, id_);
   c10::cuda::CUDACachingAllocator::emptyCache(id_);
 }
```


```diff
@@ -146,8 +146,8 @@ public:
     allocator_->createOrIncrefPool(device, mempool_id, allocator);
   }
-  void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
-    allocator_->setUseOnOOM(device, mempool_id);
+  void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id, bool use_on_oom) override {
+    allocator_->setUseOnOOM(device, mempool_id, use_on_oom);
   }
   void setNoSplit(c10::DeviceIndex device, MempoolId_t mempool_id) override {
```


```diff
@@ -151,8 +151,8 @@ inline void createOrIncrefPool(
   get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
 }
-inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  get()->setUseOnOOM(device, mempool_id);
+inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id, bool use_on_oom) {
+  get()->setUseOnOOM(device, mempool_id, use_on_oom);
 }
 inline void setNoSplit(c10::DeviceIndex device, MempoolId_t mempool_id) {
```


```diff
@@ -1434,12 +1434,17 @@ class DeviceCachingAllocator {
             0.0)) {
       garbage_collect_cached_blocks(context);
     }
     // Attempt allocate
     // WARNING: alloc_block may release the allocator lock when calling
     // cudaMalloc. So far this function has not modified allocator state, but
     // keep in mind that any observed allocator state may change across calls
     // to alloc_block since it may release the lock.
     block_found = alloc_block(params, false, context, lock)
+        // Try to use memory pools that have opted in as overflow before
+        // expensive memory freeing operations.
+        || try_mempool_fallback(
+               params, size, stream, device_id, alloc_size, stats)
         // Free enough available cached blocks to satisfy alloc and retry
         // alloc.
         || (release_available_cached_blocks(params, context) &&
@@ -1450,32 +1455,6 @@
             alloc_block(params, true, context, lock));
     }
-    // we are about to oom, try to use existing mempools as a last resort
-    if (!block_found && params.err == cudaErrorMemoryAllocation) {
-      // if already trying to use a mempool, then just oom
-      bool active_pool = params.pool->owner_PrivatePool;
-      if (!active_pool) {
-        for (MempoolId_t mempool_id : use_on_oom_pools) {
-          auto tid = std::this_thread::get_id();
-          auto filter = [tid](cudaStream_t) {
-            return std::this_thread::get_id() == tid;
-          };
-          beginAllocateToPool(mempool_id, filter);
-          auto& mempool = get_pool(size, stream);
-          AllocParams mempool_params(
-              device_id, size, stream, &mempool, alloc_size, false);
-          mempool_params.stat_types = get_stat_types_for_pool(mempool);
-          block_found = get_free_block(mempool_params);
-          endAllocateToPool(mempool_id);
-          releasePool(mempool_id);
-          if (block_found) {
-            params = mempool_params;
-            break;
-          }
-        }
-      }
-    }
     if (!block_found) {
       // For any error code other than cudaErrorMemoryAllocation,
       // alloc_block should have thrown an exception already.
@@ -1605,6 +1584,39 @@
         params, orig_size, std::move(context), split_remainder);
   }
+  bool try_mempool_fallback(
+      AllocParams& params,
+      size_t size,
+      cudaStream_t stream,
+      c10::DeviceIndex device_idx,
+      size_t alloc_size,
+      DeviceStats& device_stats) {
+    bool block_found = false;
+    // if already trying to use a mempool, then just oom
+    bool active_pool = params.pool->owner_PrivatePool;
+    if (!active_pool) {
+      for (MempoolId_t mempool_id : use_on_oom_pools) {
+        auto tid = std::this_thread::get_id();
+        auto filter = [tid](cudaStream_t) {
+          return std::this_thread::get_id() == tid;
+        };
+        beginAllocateToPool(mempool_id, filter);
+        auto& mempool = get_pool(size, stream);
+        AllocParams mempool_params(
+            device_idx, size, stream, &mempool, alloc_size, false);
+        mempool_params.stat_types = get_stat_types_for_pool(mempool);
+        block_found = get_free_block(mempool_params);
+        endAllocateToPool(mempool_id);
+        releasePool(mempool_id);
+        if (block_found) {
+          params = mempool_params;
+          break;
+        }
+      }
+    }
+    return block_found;
+  }
   Block* alloc_found_block(
       const AllocParams& params,
       size_t orig_size,
@@ -2556,10 +2568,13 @@ class DeviceCachingAllocator {
     create_or_incref_pool(mempool_id, allocator);
   }
-  void setUseOnOOM(MempoolId_t mempool_id) {
+  // Choose if this pool should be used as a last resort before ooming
+  void setUseOnOOM(MempoolId_t mempool_id, bool use_on_oom) {
     std::lock_guard<std::recursive_mutex> lock(mutex);
-    use_on_oom_pools.insert(mempool_id);
+    if (use_on_oom) {
+      use_on_oom_pools.insert(mempool_id);
+    } else {
+      use_on_oom_pools.erase(mempool_id);
+    }
   }
   void setNoSplit(MempoolId_t mempool_id) {
@@ -4236,9 +4251,12 @@ class NativeCachingAllocator : public CUDAAllocator {
         std::move(mempool_id), allocator);
   }
-  void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
+  void setUseOnOOM(
+      c10::DeviceIndex device,
+      MempoolId_t mempool_id,
+      bool use_on_oom) override {
     assertValidDevice(device);
-    device_allocator[device]->setUseOnOOM(std::move(mempool_id));
+    device_allocator[device]->setUseOnOOM(std::move(mempool_id), use_on_oom);
   }
   void setNoSplit(c10::DeviceIndex device, MempoolId_t mempool_id) override {
```


```diff
@@ -263,7 +263,10 @@ class CUDAAllocator : public DeviceAllocator {
         " does not yet support createOrIncrefPool. "
         "If you need it, please file an issue describing your use case.");
   }
-  virtual void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
+  virtual void setUseOnOOM(
+      c10::DeviceIndex device,
+      MempoolId_t mempool_id,
+      bool use_on_oom) {
     TORCH_CHECK(
         false,
         name(),
@@ -519,8 +522,11 @@ inline void createOrIncrefPool(
     CUDAAllocator* allocator_ptr = nullptr) {
   get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
 }
-inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
-  get()->setUseOnOOM(device, mempool_id);
+inline void setUseOnOOM(
+    c10::DeviceIndex device,
+    MempoolId_t mempool_id,
+    bool use_on_oom) {
+  get()->setUseOnOOM(device, mempool_id, use_on_oom);
 }
 inline void setNoSplit(c10::DeviceIndex device, MempoolId_t mempool_id) {
   get()->setNoSplit(device, mempool_id);
```

```diff
@@ -5721,6 +5721,43 @@ class TestMemPool(TestCase):
             f"but got {blocks_no_split} vs {blocks_split}",
         )
 
+    @serialTest()
+    def test_deleted_mempool_not_used_on_oom(self):
+        """
+        Test that a deleted mempool with use_on_oom=True is properly removed from use_on_oom_pools.
+        """
+        allocator, _ = self.get_dummy_allocator(check_vars=False)
+        nelem_1mb = 1024 * 1024 // 4
+
+        # set 40 mb total available memory
+        self._setup_mempool_limited_memory_test(40)
+
+        # Create many pools with use_on_oom=True, allocate memory, then delete the pools
+        for _ in range(10):
+            pool_use_on_oom = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
+            with torch.cuda.use_mem_pool(pool_use_on_oom):
+                a = torch.randn(40 * nelem_1mb, device="cuda")
+            del a
+            del pool_use_on_oom
+
+        # create new pool that we want to use_on_oom, all other pools should be deleted
+        # all available 40mb in use by mempool
+        new_pool_use_on_oom = torch.cuda.MemPool(allocator.allocator(), use_on_oom=True)
+        with torch.cuda.use_mem_pool(new_pool_use_on_oom):
+            a = torch.randn(40 * nelem_1mb, device="cuda")
+        del a
+
+        # allocate tensors that will fallback to use_on_oom pool since all available 40mb in use by mempool
+        # tensors should only use valid pool and not deleted pools
+        b = torch.randn(20 * nelem_1mb, device="cuda")
+        c = torch.randn(20 * nelem_1mb, device="cuda")
+        del b
+        del c
+        del new_pool_use_on_oom
+
+        self._teardown_mempool_limited_memory_test()
+
     def test_mempool_multithread(self):
         pool_ids = []
```