mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Include one level of stack trace in the lru_cache warning msg (#171496)
Fixes #167991 Example of the new warning message: ```python /home/guilhermel/git/pytorch313/torch/_dynamo/variables/functions.py:2159: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function at 'script.py:12'. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS=+dynamo for a DEBUG stack trace. This call originates from: File "/path/to/script.py", line 12, in bar return baz(x) torch._dynamo.utils.warn_once(msg) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/171496 Approved by: https://github.com/Lucaskabela
This commit is contained in:
committed by
PyTorch MergeBot
parent
b35a75b73d
commit
39839dbc39
@@ -1588,6 +1588,26 @@ call to a lru_cache wrapped function at: test_error_messages.py:N
|
||||
""",
|
||||
)
|
||||
|
||||
def test_lru_cache_warning(self):
|
||||
# test only the warning message itself
|
||||
@lru_cache
|
||||
def bax(x):
|
||||
return x + 1
|
||||
|
||||
def bar(x):
|
||||
return bax(x)
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def foo(x):
|
||||
return bar(x)
|
||||
|
||||
x = torch.randn(2)
|
||||
with self.assertWarnsOnceRegex(
|
||||
UserWarning,
|
||||
r"(?s).*This call originates from:\n.*File .*, line (\d+), in bar",
|
||||
):
|
||||
foo(x)
|
||||
|
||||
def test_disable_message(self):
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def outer(fn, x):
|
||||
|
||||
@@ -26,6 +26,7 @@ import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
@@ -2131,12 +2132,18 @@ class WrapperUserFunctionVariable(VariableTracker):
|
||||
module_name = getattr(target_fn, "__module__", "") or ""
|
||||
|
||||
if module_name.split(".", maxsplit=1)[0] != "torch":
|
||||
frame_summary = tx.frame_summary()
|
||||
filename = os.path.basename(frame_summary.filename)
|
||||
lineno = frame_summary.lineno
|
||||
msg = (
|
||||
"Dynamo detected a call to a `functools.lru_cache`-wrapped "
|
||||
"function. Dynamo ignores the cache wrapper and directly "
|
||||
"traces the wrapped function. Silent incorrectness is only "
|
||||
"a *potential* risk, not something we have observed. "
|
||||
'Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.'
|
||||
f"function at '{filename}:{lineno}'. Dynamo ignores the "
|
||||
"cache wrapper and directly traces the wrapped function. "
|
||||
"Silent incorrectness is only a *potential* risk, not "
|
||||
"something we have observed. "
|
||||
"Enable TORCH_LOGS=+dynamo for a DEBUG stack trace.\n\n"
|
||||
"This call originates from:\n"
|
||||
f"{''.join(traceback.format_list([frame_summary]))}"
|
||||
)
|
||||
|
||||
torch._dynamo.utils.warn_once(msg)
|
||||
|
||||
Reference in New Issue
Block a user