[xpu][feature] Add skip actions support to filter out memory trace (#170760)

# Motivation
This PR introduces a `skip_actions` option that excludes the specified event types from the memory snapshot trace, which reduces trace file size and improves HTML viewer usability. It implements the same behavior as described in https://github.com/pytorch/pytorch/issues/168183.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170760
Approved by: https://github.com/EikanWang
ghstack dependencies: #169812
This commit is contained in:
Yu, Guangye
2025-12-22 19:38:55 +00:00
committed by PyTorch MergeBot
parent e4094046bc
commit ec969a2278
2 changed files with 41 additions and 6 deletions

View File

@@ -501,6 +501,7 @@ class DeviceCachingAllocator {
std::atomic<CreateContextFn> context_recorder_;
RecordContext record_context_ = RecordContext::NEVER;
RingBuffer<TraceEntry> alloc_buffer;
std::unordered_set<TraceEntry::Action> skip_actions_list;
std::vector<std::pair<MempoolId_t, std::function<bool(sycl::queue*)>>>
captures_underway;
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
@@ -1544,6 +1545,9 @@ class DeviceCachingAllocator {
std::shared_ptr<GatheredContext> context) {
if (!record_history)
return;
bool should_skip = skip_actions_list.count(action) > 0;
if (should_skip)
return;
TraceEntry te(
action,
device,
@@ -1653,10 +1657,32 @@ class DeviceCachingAllocator {
CreateContextFn context_recorder,
size_t alloc_buffer_max_entries,
RecordContext when,
bool clearHistory) {
bool clearHistory,
const std::vector<std::string>& skip_actions) {
std::unique_lock<std::recursive_mutex> lock(mutex);
TORCH_CHECK(when == RecordContext::NEVER || context_recorder);
record_history = enabled;
static const std::unordered_map<std::string, TraceEntry::Action>
kActionMap = {
{"alloc", TraceEntry::Action::ALLOC},
{"free_requested", TraceEntry::Action::FREE_REQUESTED},
{"free_completed", TraceEntry::Action::FREE_COMPLETED},
{"segment_alloc", TraceEntry::Action::SEGMENT_ALLOC},
{"segment_free", TraceEntry::Action::SEGMENT_FREE},
{"segment_map", TraceEntry::Action::SEGMENT_MAP},
{"segment_unmap", TraceEntry::Action::SEGMENT_UNMAP},
{"snapshot", TraceEntry::Action::SNAPSHOT},
{"oom", TraceEntry::Action::OOM},
};
skip_actions_list.clear();
for (const auto& action_str : skip_actions) {
auto it = kActionMap.find(action_str);
TORCH_CHECK(it != kActionMap.end(), "Unknown skip action: ", action_str);
skip_actions_list.insert(it->second);
}
context_recorder_.store(record_history ? context_recorder : nullptr);
alloc_buffer.setMaxEntries(alloc_buffer_max_entries);
record_context_ = enabled ? when : RecordContext::NEVER;
@@ -1977,14 +2003,16 @@ class NativeCachingAllocator : public XPUAllocator {
CreateContextFn context_recorder,
size_t alloc_buffer_max_entries,
RecordContext when,
bool clearHistory) {
bool clearHistory,
const std::vector<std::string>& skip_actions) {
for (auto& allocator : device_allocators) {
allocator->recordHistory(
enabled,
context_recorder,
alloc_buffer_max_entries,
when,
clearHistory);
clearHistory,
skip_actions);
}
}
@@ -2056,9 +2084,15 @@ void recordHistory(
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
RecordContext when,
bool clearHistory) {
bool clearHistory,
const std::vector<std::string>& skip_actions) {
native_allocator.recordHistory(
enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
enabled,
context_recorder,
alloc_trace_max_entries,
when,
clearHistory,
skip_actions);
}
SnapshotInfo snapshot(MempoolId_t mempool_id) {

View File

@@ -161,7 +161,8 @@ C10_XPU_API void recordHistory(
CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
RecordContext when,
bool clearHistory);
bool clearHistory,
const std::vector<std::string>& skip_actions);
C10_XPU_API SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0});