diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 9e25d07b1e8..c8e0ae9c273 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) { static std::unordered_map dict_version_map; static int dict_version_watcher_id; +static int dict_recursive_tag_watcher_id; static uint64_t global_dict_version_id = 1; static int dict_version_watch_callback( PyDict_WatchEvent event, @@ -1557,6 +1558,37 @@ class GuardManager; class RootGuardManager; class DictGuardManager; +// Global registry used by the *recursive-dict-tag* optimisation. +// +// Key : `PyObject*` pointing to a watched `dict` +// Value : list of `GuardManager*` instances that have recorded that dict +// +// Why is this global? +// ------------------- +// * CPython allows only a small, fixed number of dict-watcher IDs (≈64). +// All `GuardManager`s therefore share a single watcher callback. +// * Different guard managers (possibly across different frames) can end up +// watching the same dictionary pointer. Therefore, we have a list of guard +// managers for each dict pointer. +// +// When is watch registered? +// * During the recording phase of recursive dict tag matching in GuardManager. +// +// When are they watched? +// * In the dict_recursive_tag_watch_callback function. +// +// When are the dict pointers unwatched? +// * If a dict is mutated or the guard manager deallocates. +// * Read `unwatch_all_saved_dict_pointers` docstring for more details. +// +// Expected size +// ------------- +// Every compilation frame contributes its tag-safe dicts to this registry, so +// the container can grow large over the lifetime of the process. That’s +// acceptable: lookup is by pointer (hash/equals = identity) and each entry +// stores only lightweight pointers. +std::unordered_map> dict_to_guard_managers; + /** * Base class for the leaf guard in the GuardManager hierarchy. */ @@ -2625,6 +2657,7 @@ class GuardManager { virtual ~GuardManager() { cleanup_tag_safe_entries(); + disable_recursive_dict_tag_optimization(); } void cleanup_tag_safe_entries() { @@ -2727,6 +2760,11 @@ class GuardManager { _tensor_pointers[value] = tensor_pointers; } + void disable_recursive_dict_tag_optimization() { + unwatch_all_saved_dict_pointers(); + _disable_dict_tag_matching = true; + } + public: // For cloning GuardManager( @@ -2833,6 +2871,10 @@ class GuardManager { } bool check_dict_pointer_tags(PyObject* value) { + if (_dict_callback_installed) { + // This means that for 3.12+, there are callbacks watching dict pointers. + return true; + } for (auto& kv : _dict_pointers[value]) { PyObject* dict_pointer = kv.first; uint64_t old_tag = kv.second; @@ -2963,6 +3005,11 @@ class GuardManager { throw std::runtime_error( "Could not register a callback for recursive dict tag optimization"); } +#if IS_PYTHON_3_12_PLUS + // Ideally we don't need to even register a weakref callback for value. + // But it does not hurt to be more cautious + _dict_callback_installed = watch_dict_pointers(value); +#endif } } if (!result) { @@ -2979,8 +3026,9 @@ class GuardManager { } GuardManager* guard_manager = static_cast( PyCapsule_GetPointer(self_capsule, "GuardManager*")); - if (guard_manager) - guard_manager->_disable_dict_tag_matching = true; + if (guard_manager) { + guard_manager->disable_recursive_dict_tag_optimization(); + } Py_RETURN_NONE; } @@ -3031,6 +3079,81 @@ class GuardManager { return true; } + bool watch_dict_pointers(PyObject* value) { +#if IS_PYTHON_3_12_PLUS + // ----------------------------------------------------------------------------- + // CPython 3.12 dict-watcher integration + // ----------------------------------------------------------------------------- + // + // We register a single watcher on all every dictionary pointer recorded by + // a tag-safe root. The watcher callback fires *once* for any structural + // change to those dictionaries + // + // Fast-path benefit + // ----------------- + // In steady state we no longer need to iterate over the recorded + // dictionaries and compare their `ma_version_tag`s (the + // “are-tags-unchanged” loop that used to dominate the fast-path guard + // evaluation). The presence of an *active watcher* is itself a guarantee + // that none of the dicts has mutated; if one **does** mutate, the callback + // simply flips `_disable_dict_tag_matching = true`, causing the next guard + // evaluation to skip the recursive-dict-tag optimisation entirely. + for (auto& kv : _dict_pointers[value]) { + PyObject* dict_pointer = kv.first; + int rc = PyDict_Watch(dict_recursive_tag_watcher_id, dict_pointer); + if (rc != 0) { + PyErr_Clear(); + return false; + } + dict_to_guard_managers[dict_pointer].push_back(this); + } +#endif + return true; + } + + void unwatch_all_saved_dict_pointers() { + /* + We may have recorded hundreds/thousands of dict pointers for the recursive + dict-tag optimisation. If any of those dicts mutates, we want to disable the + optimisation and then unwatch as many dict pointers as we can. + + Be careful: the same dict pointer can be recorded by multiple GuardManagers. + So the flow is: + + 1) Remove *this* GuardManager from dict_to_guard_managers[dict_pointer]. + 2) If the list for that dict becomes empty, then: + - PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer) + - erase the dict_pointer entry from dict_to_guard_managers. + */ +#if IS_PYTHON_3_12_PLUS + if (!_disable_dict_tag_matching) { + for (auto& value_stashed_pointers : _dict_pointers) { + auto stashed_pointers = value_stashed_pointers.second; + + for (auto& stashed_pointer : stashed_pointers) { + PyObject* dict_pointer = stashed_pointer.first; + + // Delete the guard manager from the dict_to_guard_managers + auto it = std::find( + dict_to_guard_managers[dict_pointer].begin(), + dict_to_guard_managers[dict_pointer].end(), + this); + if (it != dict_to_guard_managers[dict_pointer].end()) { + dict_to_guard_managers[dict_pointer].erase(it); + } + + // Unwatch the dict pointer if this was the last guard manager + // watching it. + if (dict_to_guard_managers[dict_pointer].empty()) { + PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer); + dict_to_guard_managers.erase(dict_pointer); + } + } + } + } +#endif + } + virtual bool check_nopybind(FrameLocalsMapping* value) { return check_nopybind_template(value); } @@ -3270,6 +3393,9 @@ class GuardManager { std::unordered_map> _tensor_pointers; std::vector _tag_safe_entries; + // 3.12+ related helper + bool _dict_callback_installed = false; + protected: // weakref to the type of guarded value // protected because it is used for cloning by DictGuardManager @@ -3957,6 +4083,27 @@ void add_relational_guard_resetter_to_cloned_root( root->add_relational_guard_resetter(std::move(guard)); } +#if IS_PYTHON_3_12_PLUS +static int dict_recursive_tag_watch_callback( + PyDict_WatchEvent event, + PyObject* dict, + PyObject* key, + PyObject* new_value) noexcept { + if (event != PyDict_EVENT_CLONED) { + auto it = dict_to_guard_managers.find(dict); + if (it != dict_to_guard_managers.end()) { + auto guard_managers = it->second; + for (auto& guard_manager : guard_managers) { + if (guard_manager) { + guard_manager->disable_recursive_dict_tag_optimization(); + } + } + } + } + return 0; // keep watching +} +#endif + std::unique_ptr make_guard_manager( RootGuardManager* root, std::string source, @@ -7558,6 +7705,13 @@ PyObject* torch_c_dynamo_guards_init() { throw std::runtime_error("Failed to install dict_version_watch_callback"); } + dict_recursive_tag_watcher_id = + PyDict_AddWatcher(dict_recursive_tag_watch_callback); + if (dict_recursive_tag_watcher_id == -1) { + throw std::runtime_error( + "Failed to install dict_recursive_tag_watch_callback"); + } + #endif return m;