Fixes #164663

## Issue

A model with multiple layers wrapped with FSDP2 registers its pre- and post-forward hooks as a group through `_MultiHandler`. This becomes a problem inside the tracker's context manager, where the hooks are reset and replaced: all of the grouped hooks share the same FSDP state pointer, so resetting one resets them all. When the output layer was given new pre- and post-forward hooks, the reset deleted the previous layer's initialization, causing a `KeyError` for the Norm layer, which no longer existed in the state.

## The Fix

Check whether there are multiple `_MultiHandler` and `RemoveHandler` objects and execute the remove hook only once.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165662
Approved by: https://github.com/sanketpurandare
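The deduplication idea behind the fix can be sketched as follows. This is a minimal illustration with hypothetical names (`MultiHandle`, `reset_hooks`), not PyTorch's actual implementation: several modules share one grouped handle, and resetting it once per module would clear state that other modules still depend on, so removal is deduplicated by object identity.

```python
class MultiHandle:
    """Stand-in for a grouped hook handle: one handle shared by many modules.

    Calling remove() clears the hooks for EVERY module in the group, which is
    why calling it once per module is destructive.
    """

    def __init__(self, hooks):
        self.hooks = hooks  # shared mapping of module name -> hook
        self.remove_count = 0

    def remove(self):
        # Removal affects the whole group, not just one module.
        self.hooks.clear()
        self.remove_count += 1


def reset_hooks(handles):
    # The fix, in miniature: deduplicate handles by identity so a handle
    # shared across modules is removed exactly once, rather than once per
    # module (which would wipe hooks already re-registered for earlier
    # modules and later surface as a KeyError).
    seen = set()
    for handle in handles:
        if id(handle) not in seen:
            seen.add(id(handle))
            handle.remove()


hooks = {"norm": object(), "output": object()}
shared = MultiHandle(hooks)

# Two modules reference the same grouped handle; remove() still runs once.
reset_hooks([shared, shared])
```

The key design point is deduplicating by identity (`id`) rather than equality, since the bug arises precisely when distinct modules hold references to the *same* handle object.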