Reorder the OOM mitigation steps so that we reuse mempools opted in via use_on_oom before taking the more expensive step of releasing cached blocks.
Additionally, make sure mempools are removed from use_on_oom_pools upon deletion.
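For context, here is a minimal sketch of the scenario the new test exercises. It assumes the `torch.cuda.MemPool` / `torch.cuda.use_mem_pool` API and the `use_on_oom=True` flag referenced in the test output below; the exact sizes and setup in `test_cuda.py` may differ.
```python
import torch

nelem_1mb = 1024 * 1024 // 4  # float32 elements per 1 MiB

# Constrain the allocator so a later allocation has to go through the
# OOM mitigation path (the real test may size this differently).
torch.cuda.set_per_process_memory_fraction(0.1)

# A pool that opts in to being reused when the allocator would otherwise OOM.
pool = torch.cuda.MemPool(use_on_oom=True)
with torch.cuda.use_mem_pool(pool):
    a = torch.randn(40 * nelem_1mb, device="cuda")
del a  # the blocks stay cached inside the pool

# Deleting the pool should also unregister it from use_on_oom_pools;
# before the fix the stale entry remained.
del pool

# This allocation exceeds the capped budget and falls back to the OOM path.
# Before the fix it touched the deleted pool, tripping the use_count
# internal assert and then segfaulting.
c = torch.randn(20 * nelem_1mb, device="cuda")
```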
New test before fix:
```
======================================================================
ERROR: test_deleted_mempool_not_used_on_oom (__main__.TestMemPool.test_deleted_mempool_not_used_on_oom)
Test that a deleted mempool with use_on_oom=True is properly removed from use_on_oom_pools.
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/danielsjohnson/oss_pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3325, in wrapper
method(*args, **kwargs)
File "/home/danielsjohnson/oss_pytorch/pytorch/test/test_cuda.py", line 5696, in test_deleted_mempool_not_used_on_oom
c = torch.randn(20 * nelem_1mb, device="cuda")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: it->second->use_count > 0 INTERNAL ASSERT FAILED at "/home/danielsjohnson/oss_pytorch/pytorch/c10/cuda/CUDACachingAllocator.cpp":2700, please report a bug to PyTorch.
To execute this test, run the following from the base repo dir:
python test/test_cuda.py TestMemPool.test_deleted_mempool_not_used_on_oom
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 0.691s
FAILED (errors=1)
Segmentation fault (core dumped)
```
New test after fix:
```
----------------------------------------------------------------------
Ran 1 test in 0.651s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/169699
Approved by: https://github.com/ngimel, https://github.com/eqy