mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[xpu][feature] Add skip actions support to filter out memory trace (#170760)
# Motivation This PR intends to introduce a flag to skip specific events in the memory snapshot to reduce trace file size and improve HTML viewer usability. This PR does the same thing as in https://github.com/pytorch/pytorch/issues/168183. Pull Request resolved: https://github.com/pytorch/pytorch/pull/170760 Approved by: https://github.com/EikanWang ghstack dependencies: #169812
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4094046bc
commit
ec969a2278
@@ -501,6 +501,7 @@ class DeviceCachingAllocator {
|
||||
std::atomic<CreateContextFn> context_recorder_;
|
||||
RecordContext record_context_ = RecordContext::NEVER;
|
||||
RingBuffer<TraceEntry> alloc_buffer;
|
||||
std::unordered_set<TraceEntry::Action> skip_actions_list;
|
||||
std::vector<std::pair<MempoolId_t, std::function<bool(sycl::queue*)>>>
|
||||
captures_underway;
|
||||
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
|
||||
@@ -1544,6 +1545,9 @@ class DeviceCachingAllocator {
|
||||
std::shared_ptr<GatheredContext> context) {
|
||||
if (!record_history)
|
||||
return;
|
||||
bool should_skip = skip_actions_list.count(action) > 0;
|
||||
if (should_skip)
|
||||
return;
|
||||
TraceEntry te(
|
||||
action,
|
||||
device,
|
||||
@@ -1653,10 +1657,32 @@ class DeviceCachingAllocator {
|
||||
CreateContextFn context_recorder,
|
||||
size_t alloc_buffer_max_entries,
|
||||
RecordContext when,
|
||||
bool clearHistory) {
|
||||
bool clearHistory,
|
||||
const std::vector<std::string>& skip_actions) {
|
||||
std::unique_lock<std::recursive_mutex> lock(mutex);
|
||||
TORCH_CHECK(when == RecordContext::NEVER || context_recorder);
|
||||
record_history = enabled;
|
||||
|
||||
static const std::unordered_map<std::string, TraceEntry::Action>
|
||||
kActionMap = {
|
||||
{"alloc", TraceEntry::Action::ALLOC},
|
||||
{"free_requested", TraceEntry::Action::FREE_REQUESTED},
|
||||
{"free_completed", TraceEntry::Action::FREE_COMPLETED},
|
||||
{"segment_alloc", TraceEntry::Action::SEGMENT_ALLOC},
|
||||
{"segment_free", TraceEntry::Action::SEGMENT_FREE},
|
||||
{"segment_map", TraceEntry::Action::SEGMENT_MAP},
|
||||
{"segment_unmap", TraceEntry::Action::SEGMENT_UNMAP},
|
||||
{"snapshot", TraceEntry::Action::SNAPSHOT},
|
||||
{"oom", TraceEntry::Action::OOM},
|
||||
};
|
||||
|
||||
skip_actions_list.clear();
|
||||
for (const auto& action_str : skip_actions) {
|
||||
auto it = kActionMap.find(action_str);
|
||||
TORCH_CHECK(it != kActionMap.end(), "Unknown skip action: ", action_str);
|
||||
skip_actions_list.insert(it->second);
|
||||
}
|
||||
|
||||
context_recorder_.store(record_history ? context_recorder : nullptr);
|
||||
alloc_buffer.setMaxEntries(alloc_buffer_max_entries);
|
||||
record_context_ = enabled ? when : RecordContext::NEVER;
|
||||
@@ -1977,14 +2003,16 @@ class NativeCachingAllocator : public XPUAllocator {
|
||||
CreateContextFn context_recorder,
|
||||
size_t alloc_buffer_max_entries,
|
||||
RecordContext when,
|
||||
bool clearHistory) {
|
||||
bool clearHistory,
|
||||
const std::vector<std::string>& skip_actions) {
|
||||
for (auto& allocator : device_allocators) {
|
||||
allocator->recordHistory(
|
||||
enabled,
|
||||
context_recorder,
|
||||
alloc_buffer_max_entries,
|
||||
when,
|
||||
clearHistory);
|
||||
clearHistory,
|
||||
skip_actions);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2056,9 +2084,15 @@ void recordHistory(
|
||||
CreateContextFn context_recorder,
|
||||
size_t alloc_trace_max_entries,
|
||||
RecordContext when,
|
||||
bool clearHistory) {
|
||||
bool clearHistory,
|
||||
const std::vector<std::string>& skip_actions) {
|
||||
native_allocator.recordHistory(
|
||||
enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
|
||||
enabled,
|
||||
context_recorder,
|
||||
alloc_trace_max_entries,
|
||||
when,
|
||||
clearHistory,
|
||||
skip_actions);
|
||||
}
|
||||
|
||||
SnapshotInfo snapshot(MempoolId_t mempool_id) {
|
||||
|
||||
@@ -161,7 +161,8 @@ C10_XPU_API void recordHistory(
|
||||
CreateContextFn context_recorder,
|
||||
size_t alloc_trace_max_entries,
|
||||
RecordContext when,
|
||||
bool clearHistory);
|
||||
bool clearHistory,
|
||||
const std::vector<std::string>& skip_actions);
|
||||
|
||||
C10_XPU_API SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user