[xpu][feature] [3/6] Add snapshot support on XPU caching allocator (#169203)

# Motivation
This PR introduces memory snapshot functionality for the XPU caching allocator. Our design philosophy is to keep the implementation as simple as possible without unnecessary features. We will be able to extend the functionality in the future if real use cases arise. The `c10::xpu::XPUCachingAllocator::snapshot` API introduced in this PR will be leveraged by the Python frontend in a follow-up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/169203
Approved by: https://github.com/albanD
ghstack dependencies: #168262, #169280
This commit is contained in:
Yu, Guangye
2025-12-12 16:25:36 +00:00
committed by PyTorch MergeBot
parent 5cc4ebf398
commit 118b0d9037
2 changed files with 181 additions and 0 deletions

View File

@@ -540,6 +540,16 @@ class DeviceCachingAllocator {
return subsumed_size;
}
std::vector<Block*> get_all_blocks() const {
std::vector<Block*> blocks;
blocks.insert(
blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end());
blocks.insert(
blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end());
blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end());
return blocks;
}
void free_block(
Block* block,
const std::shared_ptr<GatheredContext>& context) {
@@ -1214,6 +1224,29 @@ class DeviceCachingAllocator {
}
}
std::vector<Block*> get_private_pool_head_blocks(PrivatePool* pool) const {
std::vector<Block*> blocks;
for (Block* b : active_blocks) {
if ((b->pool == &pool->small_blocks || b->pool == &pool->large_blocks) &&
b->prev == nullptr) {
blocks.push_back(b);
}
}
for (Block* b : pool->small_blocks.blocks) {
if (b->prev == nullptr) {
blocks.push_back(b);
}
}
for (Block* b : pool->large_blocks.blocks) {
if (b->prev == nullptr) {
blocks.push_back(b);
}
}
return blocks;
}
void create_or_incref_pool(
MempoolId_t mempool_id,
XPUAllocator* allocator = nullptr) {
@@ -1449,6 +1482,85 @@ class DeviceCachingAllocator {
alloc_buffer.insertEntries(te);
}
std::vector<SegmentInfo> snapshot(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<Block*> all_blocks;
if (mempool_id.first != 0 || mempool_id.second != 0) {
// If there is an active mempool, we find the corresponding PrivatePool
// in graph_pools and only return the blocks from it.
auto pool = graph_pools.find(mempool_id);
if (pool != graph_pools.end()) {
all_blocks = get_private_pool_head_blocks(pool->second.get());
}
} else {
// When snapshot is called with non-default mempool_id, we return
// all the blocks from all pools.
all_blocks = get_all_blocks();
}
size_t total_active = 0;
std::vector<SegmentInfo> result;
for (const Block* const head_block : all_blocks) {
// For expandable segments, we report one segment for each contiguous
// mapped range of memory
if (head_block->prev && head_block->prev->mapped) {
continue;
}
result.emplace_back();
SegmentInfo& segment_info = result.back();
segment_info.device = head_block->device;
segment_info.address = reinterpret_cast<size_t>(head_block->ptr);
segment_info.queue = head_block->queue;
segment_info.is_large = (!head_block->pool->is_small);
segment_info.is_expandable = head_block->expandable_segment;
segment_info.context_when_allocated =
head_block->context_when_segment_allocated;
MempoolId_t id = head_block->pool->owner_MempoolId();
if ((mempool_id.first == 0 && mempool_id.second == 0) ||
id == mempool_id) {
segment_info.owner_private_pool_id = id;
}
const Block* block = head_block;
while (block != nullptr && block->mapped) {
segment_info.blocks.emplace_back();
BlockInfo& block_info = segment_info.blocks.back();
block_info.size = block->size;
block_info.requested_size = block->requested_size;
block_info.allocated = block->allocated;
block_info.active = block->allocated || (block->event_count > 0) ||
!block->stream_uses.empty();
segment_info.total_size += block_info.size;
if (block_info.allocated) {
segment_info.allocated_size += block_info.size;
}
if (block_info.active) {
segment_info.active_size += block_info.size;
segment_info.requested_size += block_info.requested_size;
}
block_info.context_when_allocated = block->context_when_allocated;
block = block->next;
}
total_active += segment_info.active_size;
}
std::sort(
result.begin(),
result.end(),
[](const SegmentInfo& a, const SegmentInfo& b) {
return a.address < b.address;
});
record_trace(
TraceEntry::SNAPSHOT, 0, total_active, nullptr, 0, mempool_id, nullptr);
return result;
}
std::vector<TraceEntry> trace(
const std::function<time_t(approx_time_t)>& tsc_to_us) const {
std::lock_guard<std::recursive_mutex> lock(mutex);
@@ -1579,6 +1691,7 @@ class NativeCachingAllocator : public XPUAllocator {
private:
alignas(hardware_destructive_interference_size) std::mutex mutex;
ska::flat_hash_map<void*, Block*> allocated_blocks;
c10::ApproximateClockToUnixTimeConverter clock_converter;
void add_allocated_block(Block* block) {
std::lock_guard<std::mutex> lock(mutex);
@@ -1734,6 +1847,30 @@ class NativeCachingAllocator : public XPUAllocator {
device_allocators[device]->resetAccumulatedStats();
}
SnapshotInfo snapshot(MempoolId_t mempool_id) {
// Set-up converter to convert timestamps from tsc to microseconds.
auto tsc_to_ns = clock_converter.makeConverter();
auto tsc_to_us = [=](approx_time_t t_approx) {
return tsc_to_ns(t_approx) / 1000;
};
SnapshotInfo result;
// Get the device_traces' TraceEntry lists.
for (auto& da : device_allocators) {
result.device_traces.emplace_back(da->trace(tsc_to_us));
auto snap = da->snapshot(mempool_id);
result.segments.insert(result.segments.end(), snap.begin(), snap.end());
}
auto& md = result.config_metadata;
md.expandable_segments =
AcceleratorAllocatorConfig::use_expandable_segments();
md.last_allocator_settings =
AcceleratorAllocatorConfig::last_allocator_settings();
return result;
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
assertValidDevice(dev);
assertValidDevice(dev_to_access);
@@ -1850,6 +1987,10 @@ void recordHistory(
enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}
SnapshotInfo snapshot(MempoolId_t mempool_id) {
return native_allocator.snapshot(mempool_id);
}
void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,

View File

@@ -25,6 +25,33 @@ enum struct RecordContext {
ALL = 3, // additionally record stacks for when something is freed
};
// Struct containing info of an allocation block
struct BlockInfo {
size_t size = 0;
size_t requested_size = 0;
int32_t gc_counter = 0;
bool allocated = false;
bool active = false;
std::shared_ptr<GatheredContext> context_when_allocated;
};
// Struct containing info of a memory segment (i.e. one contiguous device memory
// allocation).
struct SegmentInfo {
c10::DeviceIndex device = 0;
size_t address = 0;
size_t total_size = 0;
size_t requested_size = 0; // unrounded, actually requested size
size_t allocated_size = 0;
size_t active_size = 0;
sycl::queue* queue = nullptr;
bool is_large = false;
bool is_expandable = false;
MempoolId_t owner_private_pool_id = {0, 0};
std::vector<BlockInfo> blocks;
std::shared_ptr<GatheredContext> context_when_allocated;
};
union trace_time_ {
time_t t_;
approx_time_t approx_t_;
@@ -73,6 +100,17 @@ struct TraceEntry {
trace_time_ time_{};
};
struct AllocatorConfigInfo {
bool expandable_segments;
std::string last_allocator_settings;
};
struct SnapshotInfo {
std::vector<SegmentInfo> segments;
std::vector<std::vector<TraceEntry>> device_traces;
AllocatorConfigInfo config_metadata;
};
inline XPUAllocator* get() {
return allocator.load();
}
@@ -125,6 +163,8 @@ C10_XPU_API void recordHistory(
RecordContext when,
bool clearHistory);
C10_XPU_API SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0});
C10_XPU_API void createOrIncrefPool(
c10::DeviceIndex device,
c10::MempoolId_t mempool_id,