mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Add per_process_memory_fraction to PYTORCH_CUDA_ALLOC_CONF (#161035)
torch.cuda.memory.set_per_process_memory_fraction allows setting an upper bound on how much device memory is allocated. This PR exposes this setting to an environment variable. For example, PYTORCH_CUDA_ALLOC_CONF="per_process_memory_fraction:0.5" will limit the device memory to half of the available memory. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161035 Approved by: https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
da2eb31b82
commit
7b055a0103
@@ -106,6 +106,9 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) {
|
||||
} else if (key == "graph_capture_record_stream_reuse") {
|
||||
i = parseGraphCaptureRecordStreamReuse(tokenizer, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (key == "per_process_memory_fraction") {
|
||||
i = parsePerProcessMemoryFraction(tokenizer, i);
|
||||
used_native_specific_option = true;
|
||||
} else {
|
||||
const auto& keys =
|
||||
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
|
||||
@@ -146,6 +149,18 @@ size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse(
|
||||
return i;
|
||||
}
|
||||
|
||||
double CUDAAllocatorConfig::parsePerProcessMemoryFraction(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
double val_env = tokenizer.toDouble(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
val_env >= 0.0 && val_env <= 1.0,
|
||||
"per_process_memory_fraction is invalid, set it in [0.0, 1.0]");
|
||||
m_per_process_memory_fraction = val_env;
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
|
||||
@@ -61,6 +61,10 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
return instance().m_graph_capture_record_stream_reuse;
|
||||
}
|
||||
|
||||
static double per_process_memory_fraction() {
|
||||
return instance().m_per_process_memory_fraction;
|
||||
}
|
||||
|
||||
/** Pinned memory allocator settings */
|
||||
static bool pinned_use_cuda_host_register() {
|
||||
return instance().m_pinned_use_cuda_host_register;
|
||||
@@ -152,7 +156,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
"pinned_use_hip_host_register",
|
||||
"graph_capture_record_stream_reuse",
|
||||
"pinned_reserve_segment_size_mb",
|
||||
"pinned_num_register_threads"};
|
||||
"pinned_num_register_threads",
|
||||
"per_process_memory_fraction"};
|
||||
return keys;
|
||||
}
|
||||
|
||||
@@ -177,6 +182,9 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
size_t parseGraphCaptureRecordStreamReuse(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
double parsePerProcessMemoryFraction(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
|
||||
std::atomic<size_t> m_pinned_num_register_threads{1};
|
||||
std::atomic<size_t> m_pinned_reserve_segment_size_mb{0};
|
||||
@@ -189,6 +197,7 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
std::atomic<bool> m_release_lock_on_cudamalloc{false};
|
||||
std::atomic<bool> m_pinned_use_cuda_host_register{false};
|
||||
std::atomic<bool> m_graph_capture_record_stream_reuse{false};
|
||||
std::atomic<double> m_per_process_memory_fraction{1.0};
|
||||
};
|
||||
|
||||
// Keep this for backwards compatibility
|
||||
|
||||
@@ -1100,7 +1100,7 @@ class RingBuffer {
|
||||
} // anonymous namespace
|
||||
} // namespace Native
|
||||
|
||||
static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
|
||||
static std::string reportProcessMemoryInfo(const cudaDeviceProp& prop) {
|
||||
#ifdef PYTORCH_C10_DRIVER_API_SUPPORTED
|
||||
void* nvml_handle = DriverAPI::get_nvml_handle();
|
||||
if (!nvml_handle) {
|
||||
@@ -1111,9 +1111,6 @@ static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
|
||||
return true;
|
||||
}();
|
||||
|
||||
cudaDeviceProp prop{};
|
||||
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
|
||||
|
||||
// NOLINTNEXTLINE(*-c-arrays)
|
||||
char pci_id[80];
|
||||
snprintf(
|
||||
@@ -1215,14 +1212,16 @@ class DeviceCachingAllocator {
|
||||
// record used memory.
|
||||
size_t total_allocated_memory = 0;
|
||||
|
||||
size_t allowed_memory_maximum = 0;
|
||||
cudaDeviceProp device_prop;
|
||||
|
||||
// maximum amount of memory that device is allowed to
|
||||
// allocate. This is set iff memory fraction is less than 1
|
||||
std::optional<size_t> allowed_memory_maximum{std::nullopt};
|
||||
|
||||
// all live expandable segments
|
||||
std::vector<ExpandableSegment*> expandable_segments_;
|
||||
std::vector<c10::DeviceIndex> devices_with_peer_access_;
|
||||
|
||||
bool set_fraction = false;
|
||||
|
||||
bool record_history = false;
|
||||
|
||||
std::atomic<CreateContextFn> context_recorder_;
|
||||
@@ -1264,6 +1263,9 @@ class DeviceCachingAllocator {
|
||||
: device_id(id),
|
||||
large_blocks(/*small=*/false),
|
||||
small_blocks(/*small=*/true) {
|
||||
C10_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, id));
|
||||
|
||||
setMemoryFraction(CUDAAllocatorConfig::per_process_memory_fraction());
|
||||
stats.max_split_size =
|
||||
static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
|
||||
context_recorder_.store(nullptr);
|
||||
@@ -1399,7 +1401,7 @@ class DeviceCachingAllocator {
|
||||
if (!block_found) {
|
||||
// Do garbage collection if the flag is set.
|
||||
if (C10_UNLIKELY(
|
||||
set_fraction &&
|
||||
allowed_memory_maximum.has_value() &&
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() >
|
||||
0.0)) {
|
||||
garbage_collect_cached_blocks(context);
|
||||
@@ -1456,11 +1458,12 @@ class DeviceCachingAllocator {
|
||||
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
|
||||
std::string allowed_info;
|
||||
|
||||
if (set_fraction) {
|
||||
allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
|
||||
if (allowed_memory_maximum.has_value()) {
|
||||
allowed_info =
|
||||
format_size(allowed_memory_maximum.value()) + " allowed; ";
|
||||
}
|
||||
|
||||
std::string proc_info = reportProcessMemoryInfo(device_id);
|
||||
std::string proc_info = reportProcessMemoryInfo(device_prop);
|
||||
|
||||
record_trace(
|
||||
TraceEntry::OOM,
|
||||
@@ -1518,7 +1521,7 @@ class DeviceCachingAllocator {
|
||||
for (const auto& obs : observers_local) {
|
||||
obs(device_id,
|
||||
alloc_size,
|
||||
set_fraction ? allowed_memory_maximum : device_total,
|
||||
allowed_memory_maximum.value_or(device_total),
|
||||
device_free);
|
||||
}
|
||||
|
||||
@@ -2015,25 +2018,26 @@ class DeviceCachingAllocator {
|
||||
|
||||
/** get memory fraction limiting maximum allocated memory **/
|
||||
double getMemoryFraction() {
|
||||
if (!set_fraction) {
|
||||
if (!allowed_memory_maximum.has_value()) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
size_t device_free = 0;
|
||||
size_t device_total = 0;
|
||||
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
|
||||
return static_cast<double>(allowed_memory_maximum) /
|
||||
static_cast<double>(device_total);
|
||||
return static_cast<double>(allowed_memory_maximum.value()) /
|
||||
static_cast<double>(device_prop.totalGlobalMem);
|
||||
}
|
||||
|
||||
/** set memory fraction to limit maximum allocated memory **/
|
||||
void setMemoryFraction(double fraction) {
|
||||
size_t device_free = 0;
|
||||
size_t device_total = 0;
|
||||
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
|
||||
allowed_memory_maximum =
|
||||
static_cast<size_t>(fraction * static_cast<double>(device_total));
|
||||
set_fraction = true;
|
||||
TORCH_CHECK(
|
||||
0 <= fraction && fraction <= 1,
|
||||
"invalid fraction:",
|
||||
fraction,
|
||||
". Please set within [0, 1].");
|
||||
allowed_memory_maximum = std::nullopt;
|
||||
if (fraction < 1.0) {
|
||||
allowed_memory_maximum = static_cast<size_t>(
|
||||
fraction * static_cast<double>(device_prop.totalGlobalMem));
|
||||
}
|
||||
}
|
||||
|
||||
/** get expandable segment size for all the streams on device **/
|
||||
@@ -3010,7 +3014,7 @@ class DeviceCachingAllocator {
|
||||
BlockPool& pool = *p.pool;
|
||||
|
||||
if (C10_UNLIKELY(
|
||||
set_fraction &&
|
||||
allowed_memory_maximum.has_value() &&
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
|
||||
// Track block reuse interval only when garbage collection is enabled.
|
||||
++pool.get_free_blocks_call_count;
|
||||
@@ -3083,7 +3087,7 @@ class DeviceCachingAllocator {
|
||||
|
||||
size_t gc_threshold = static_cast<size_t>(
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() *
|
||||
static_cast<double>(allowed_memory_maximum));
|
||||
static_cast<double>(allowed_memory_maximum.value()));
|
||||
// No need to trigger GC yet
|
||||
if (total_allocated_memory <= gc_threshold) {
|
||||
return;
|
||||
@@ -3161,8 +3165,8 @@ class DeviceCachingAllocator {
|
||||
|
||||
bool active_pool =
|
||||
p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator();
|
||||
if (set_fraction &&
|
||||
total_allocated_memory + size > allowed_memory_maximum) {
|
||||
if (allowed_memory_maximum.has_value() &&
|
||||
total_allocated_memory + size > allowed_memory_maximum.value()) {
|
||||
p.err = cudaErrorMemoryAllocation;
|
||||
return false;
|
||||
// Temporarily disable checkpointing & cudagraphs internally
|
||||
@@ -3859,7 +3863,6 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
"Allocator not initialized for device ",
|
||||
device,
|
||||
": did you call init?");
|
||||
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
|
||||
return device_allocator[device]->getMemoryFraction();
|
||||
}
|
||||
|
||||
@@ -3869,12 +3872,6 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
"Allocator not initialized for device ",
|
||||
device,
|
||||
": did you call init?");
|
||||
TORCH_CHECK(
|
||||
0 <= fraction && fraction <= 1,
|
||||
"invalid fraction:",
|
||||
fraction,
|
||||
". Please set within [0, 1].");
|
||||
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
|
||||
device_allocator[device]->setMemoryFraction(fraction);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/core/CachingDeviceAllocator.h>
|
||||
#include <c10/cuda/CUDAAllocatorConfig.h>
|
||||
#include <c10/cuda/CUDAGraphsC10Utils.h>
|
||||
#include <c10/cuda/CUDAMacros.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
@@ -427,7 +427,6 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
|
||||
// on the current device each later call sees.
|
||||
void init(int dev_count) override {
|
||||
static bool called = [](int dev_count) {
|
||||
;
|
||||
// Are there external guarantees init will be called before
|
||||
// any of the allocator's other functions?
|
||||
// std::lock_guard<std::mutex> lk(general_mutex);
|
||||
|
||||
@@ -619,6 +619,10 @@ Available options:
|
||||
and reallocate buffers across multiple streams, especially when the capture DAG frequently
|
||||
reaches joined frontiers.
|
||||
|
||||
* ``per_process_memory_fraction`` option limits the amount of memory that can be allocated
|
||||
on all the CUDA devices to a specified fraction of the available memory. This is a value
|
||||
between 0 and 1. Attempting to allocate more memory will raise an out of memory error.
|
||||
|
||||
.. note::
|
||||
|
||||
Some stats reported by the
|
||||
|
||||
@@ -4626,6 +4626,52 @@ print(torch.cuda.get_allocator_backend())
|
||||
rc = check_output(test_script)
|
||||
self.assertEqual(rc, "cudaMallocAsync")
|
||||
|
||||
def test_allocator_memory_fraction_setting(self):
|
||||
def make_env(fraction):
|
||||
env = os.environ.copy()
|
||||
var = "PYTORCH_CUDA_ALLOC_CONF"
|
||||
key = "per_process_memory_fraction"
|
||||
value = [
|
||||
x
|
||||
for x in env.get(var, "").split(",")
|
||||
if len(x) > 0 and not x.startswith(f"{key}:")
|
||||
]
|
||||
value.append(f"{key}:{fraction}")
|
||||
env[var] = ",".join(value)
|
||||
return env
|
||||
|
||||
def run_test(value):
|
||||
test_script = """\
|
||||
import os
|
||||
import torch
|
||||
device = torch._C._cuda_getDevice()
|
||||
value = torch.cuda.memory.get_per_process_memory_fraction(device)
|
||||
print(value, end="")
|
||||
"""
|
||||
return subprocess.run(
|
||||
[sys.executable, "-c", test_script],
|
||||
env=make_env(value),
|
||||
text=True,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
self.assertEqual(run_test(0.0).stdout, "0.0")
|
||||
self.assertEqual(run_test(0.5).stdout, "0.5")
|
||||
self.assertEqual(run_test(1.0).stdout, "1.0")
|
||||
|
||||
with self.assertRaises(subprocess.CalledProcessError) as e:
|
||||
run_test(-0.1)
|
||||
assert "per_process_memory_fraction is invalid" in e.exception.stderr, (
|
||||
e.exception.stderr
|
||||
)
|
||||
|
||||
with self.assertRaises(subprocess.CalledProcessError) as e:
|
||||
run_test(1.1)
|
||||
assert "per_process_memory_fraction is invalid" in e.exception.stderr, (
|
||||
e.exception.stderr
|
||||
)
|
||||
|
||||
def test_cachingAllocator_raw_alloc(self):
|
||||
# Test that raw_alloc respects the setting that
|
||||
# activates/deactivates the caching allocator
|
||||
|
||||
Reference in New Issue
Block a user