[dynamo][guards] Install dict watchers for recrusive dict tag optimization (#159796)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159796
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain
2025-08-11 22:09:51 -07:00
committed by PyTorch MergeBot
parent f990490a23
commit 4d5b3f2d5a

View File

@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
static std::unordered_map<PyObject*, uint64_t> dict_version_map;
static int dict_version_watcher_id;
static int dict_recursive_tag_watcher_id;
static uint64_t global_dict_version_id = 1;
static int dict_version_watch_callback(
PyDict_WatchEvent event,
@@ -1557,6 +1558,37 @@ class GuardManager;
class RootGuardManager;
class DictGuardManager;
// Global registry used by the *recursive-dict-tag* optimisation.
//
// Key : `PyObject*` pointing to a watched `dict`
// Value : list of `GuardManager*` instances that have recorded that dict
//
// Why is this global?
// -------------------
// * CPython allows only a small, fixed number of dict-watcher IDs (≈64).
// All `GuardManager`s therefore share a single watcher callback.
// * Different guard managers (possibly across different frames) can end up
// watching the same dictionary pointer. Therefore, we have a list of guard
// managers for each dict pointer.
//
// When is watch registered?
// * During the recording phase of recursive dict tag matching in GuardManager.
//
// When are they watched?
// * In the dict_recursive_tag_watch_callback function.
//
// When are the dict pointers unwatched?
// * If a dict is mutated or the guard manager deallocates.
// * Read `unwatch_all_saved_dict_pointers` docstring for more details.
//
// Expected size
// -------------
// Every compilation frame contributes its tag-safe dicts to this registry, so
// the container can grow large over the lifetime of the process. Thats
// acceptable: lookup is by pointer (hash/equals = identity) and each entry
// stores only lightweight pointers.
std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_managers;
/**
* Base class for the leaf guard in the GuardManager hierarchy.
*/
@@ -2625,6 +2657,7 @@ class GuardManager {
virtual ~GuardManager() {
cleanup_tag_safe_entries();
disable_recursive_dict_tag_optimization();
}
void cleanup_tag_safe_entries() {
@@ -2727,6 +2760,11 @@ class GuardManager {
_tensor_pointers[value] = tensor_pointers;
}
void disable_recursive_dict_tag_optimization() {
unwatch_all_saved_dict_pointers();
_disable_dict_tag_matching = true;
}
public:
// For cloning
GuardManager(
@@ -2833,6 +2871,10 @@ class GuardManager {
}
bool check_dict_pointer_tags(PyObject* value) {
if (_dict_callback_installed) {
// This means that for 3.12+, there are callbacks watching dict pointers.
return true;
}
for (auto& kv : _dict_pointers[value]) {
PyObject* dict_pointer = kv.first;
uint64_t old_tag = kv.second;
@@ -2963,6 +3005,11 @@ class GuardManager {
throw std::runtime_error(
"Could not register a callback for recursive dict tag optimization");
}
#if IS_PYTHON_3_12_PLUS
// Ideally we don't need to even register a weakref callback for value.
// But it does not hurt to be more cautious
_dict_callback_installed = watch_dict_pointers(value);
#endif
}
}
if (!result) {
@@ -2979,8 +3026,9 @@ class GuardManager {
}
GuardManager* guard_manager = static_cast<GuardManager*>(
PyCapsule_GetPointer(self_capsule, "GuardManager*"));
if (guard_manager)
guard_manager->_disable_dict_tag_matching = true;
if (guard_manager) {
guard_manager->disable_recursive_dict_tag_optimization();
}
Py_RETURN_NONE;
}
@@ -3031,6 +3079,81 @@ class GuardManager {
return true;
}
bool watch_dict_pointers(PyObject* value) {
#if IS_PYTHON_3_12_PLUS
// -----------------------------------------------------------------------------
// CPython 3.12 dict-watcher integration
// -----------------------------------------------------------------------------
//
// We register a single watcher on all every dictionary pointer recorded by
// a tag-safe root. The watcher callback fires *once* for any structural
// change to those dictionaries
//
// Fast-path benefit
// -----------------
// In steady state we no longer need to iterate over the recorded
// dictionaries and compare their `ma_version_tag`s (the
// “are-tags-unchanged” loop that used to dominate the fast-path guard
// evaluation). The presence of an *active watcher* is itself a guarantee
// that none of the dicts has mutated; if one **does** mutate, the callback
// simply flips `_disable_dict_tag_matching = true`, causing the next guard
// evaluation to skip the recursive-dict-tag optimisation entirely.
for (auto& kv : _dict_pointers[value]) {
PyObject* dict_pointer = kv.first;
int rc = PyDict_Watch(dict_recursive_tag_watcher_id, dict_pointer);
if (rc != 0) {
PyErr_Clear();
return false;
}
dict_to_guard_managers[dict_pointer].push_back(this);
}
#endif
return true;
}
void unwatch_all_saved_dict_pointers() {
/*
We may have recorded hundreds/thousands of dict pointers for the recursive
dict-tag optimisation. If any of those dicts mutates, we want to disable the
optimisation and then unwatch as many dict pointers as we can.
Be careful: the same dict pointer can be recorded by multiple GuardManagers.
So the flow is:
1) Remove *this* GuardManager from dict_to_guard_managers[dict_pointer].
2) If the list for that dict becomes empty, then:
- PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer)
- erase the dict_pointer entry from dict_to_guard_managers.
*/
#if IS_PYTHON_3_12_PLUS
if (!_disable_dict_tag_matching) {
for (auto& value_stashed_pointers : _dict_pointers) {
auto stashed_pointers = value_stashed_pointers.second;
for (auto& stashed_pointer : stashed_pointers) {
PyObject* dict_pointer = stashed_pointer.first;
// Delete the guard manager from the dict_to_guard_managers
auto it = std::find(
dict_to_guard_managers[dict_pointer].begin(),
dict_to_guard_managers[dict_pointer].end(),
this);
if (it != dict_to_guard_managers[dict_pointer].end()) {
dict_to_guard_managers[dict_pointer].erase(it);
}
// Unwatch the dict pointer if this was the last guard manager
// watching it.
if (dict_to_guard_managers[dict_pointer].empty()) {
PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer);
dict_to_guard_managers.erase(dict_pointer);
}
}
}
}
#endif
}
virtual bool check_nopybind(FrameLocalsMapping* value) {
return check_nopybind_template(value);
}
@@ -3270,6 +3393,9 @@ class GuardManager {
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
std::vector<WeakEntry> _tag_safe_entries;
// 3.12+ related helper
bool _dict_callback_installed = false;
protected:
// weakref to the type of guarded value
// protected because it is used for cloning by DictGuardManager
@@ -3957,6 +4083,27 @@ void add_relational_guard_resetter_to_cloned_root(
root->add_relational_guard_resetter(std::move(guard));
}
#if IS_PYTHON_3_12_PLUS
static int dict_recursive_tag_watch_callback(
PyDict_WatchEvent event,
PyObject* dict,
PyObject* key,
PyObject* new_value) noexcept {
if (event != PyDict_EVENT_CLONED) {
auto it = dict_to_guard_managers.find(dict);
if (it != dict_to_guard_managers.end()) {
auto guard_managers = it->second;
for (auto& guard_manager : guard_managers) {
if (guard_manager) {
guard_manager->disable_recursive_dict_tag_optimization();
}
}
}
}
return 0; // keep watching
}
#endif
std::unique_ptr<GuardManager> make_guard_manager(
RootGuardManager* root,
std::string source,
@@ -7558,6 +7705,13 @@ PyObject* torch_c_dynamo_guards_init() {
throw std::runtime_error("Failed to install dict_version_watch_callback");
}
dict_recursive_tag_watcher_id =
PyDict_AddWatcher(dict_recursive_tag_watch_callback);
if (dict_recursive_tag_watcher_id == -1) {
throw std::runtime_error(
"Failed to install dict_recursive_tag_watch_callback");
}
#endif
return m;