[dynamo] Introduce _set_lru_cache (#167038)

Addresses the short-term plan for https://github.com/pytorch/pytorch/issues/166926 by adding a `_set_lru_cache` binding that can disable the cache's move-to-front reordering. This behavior can't be on by default: searching cache entries in plain insertion order would be terrible for cache lookup times.

There's a proper fix in the works by @williamwen42.
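
Usage mirrors the new test below: the toggle lives on `torch._C._dynamo.eval_frame` and should be restored in a `finally` block, e.g.

```python
import torch

@torch.compile(backend="eager")
def fn(x):
    return x + 10

# Keep cache entries in insertion order instead of moving hits to the
# front; restore the default LRU behavior afterwards.
torch._C._dynamo.eval_frame._set_lru_cache(False)
try:
    fn(torch.randn(10, 10))
finally:
    torch._C._dynamo.eval_frame._set_lru_cache(True)
```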

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167038
Approved by: https://github.com/williamwen42
Author: Simon Fan
Date: 2025-11-04 19:59:19 -08:00
Committed by: PyTorch MergeBot
Parent: edd8d356b6
Commit: 0b4dd08e04

4 changed files with 114 additions and 3 deletions

test/dynamo/test_repros.py

@@ -48,6 +48,7 @@ from torch._dynamo.testing import (
     CompileCounter,
     CompileCounterWithBackend,
     EagerAndRecordGraphs,
+    expectedFailureDynamic,
     rand_strided,
     same,
     skipIfNotPy312,
@@ -7455,6 +7456,93 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
             msg,
         )

+    @expectedFailureDynamic
+    def test_dynamo_default_lru_cache_behavior(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            return x + 10
+
+        torch._dynamo.reset()
+        assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+
+        # Step 1: Compile a static shapes graph
+        x = torch.randn(10, 10)
+        fn(x)
+        a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+        self.assertEqual(len(a), 1)
+        static_shapes_cache_entry = a[0]
+
+        # Step 2: Compile a dynamic shapes graph
+        y = torch.randn(20, 20)
+        fn(y)
+        b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+        self.assertEqual(len(b), 2)
+        self.assertEqual(b[1], static_shapes_cache_entry)
+        dynamic_shapes_cache_entry = b[0]
+
+        # Step 3: Run with Step 1's inputs
+        # The LRU cache will match against the dynamic shapes graph first
+        fn(x)
+        c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+        self.assertEqual(len(c), 2)
+        self.assertEqual(c[0], dynamic_shapes_cache_entry)
+        self.assertEqual(c[1], static_shapes_cache_entry)
+
+    @expectedFailureDynamic
+    def test_dynamo_disable_lru_cache_behavior(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            return x + 10
+
+        def run():
+            torch._dynamo.reset()
+            assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+
+            # Step 1: Compile a static shapes graph
+            x = torch.randn(10, 10)
+            fn(x)
+            a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+            self.assertEqual(len(a), 1)
+            static_shapes_cache_entry = a[0]
+
+            # Step 2: Compile a dynamic shapes graph
+            y = torch.randn(20, 20)
+            fn(y)
+            b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+            self.assertEqual(len(b), 2)
+            self.assertEqual(b[0], static_shapes_cache_entry)
+            dynamic_shapes_cache_entry = b[1]
+
+            # Step 3: Run with Step 1's inputs
+            # The LRU cache is disabled, so the static entry should still be first
+            fn(x)
+            c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+            self.assertEqual(len(c), 2)
+            self.assertEqual(c[0], static_shapes_cache_entry)
+            self.assertEqual(c[1], dynamic_shapes_cache_entry)
+
+        try:
+            torch._C._dynamo.eval_frame._set_lru_cache(False)
+            run()
+        finally:
+            torch._C._dynamo.eval_frame._set_lru_cache(True)
+
+
 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):

torch/csrc/dynamo/extra_state.cpp

@@ -13,6 +13,11 @@
 #define _PyCode_SetExtra PyUnstable_Code_SetExtra
 #endif

+namespace {
+// Short-term fix for: https://github.com/pytorch/pytorch/issues/166926
+bool use_lru = true;
+} // namespace
+
 Py_ssize_t extra_index = -1;

 CacheEntry* ExtraState::get_first_entry() {
@@ -190,7 +195,9 @@ void lookup(
     ++index;
   }
   if (found) {
-    extra_state->move_to_front(found);
+    if (use_lru) {
+      extra_state->move_to_front(found);
+    }
     *maybe_cached_code = found->code.ptr();
     *trace_annotation = found->trace_annotation.c_str();
     return;
@@ -202,8 +209,14 @@ CacheEntry* create_cache_entry(
     ExtraState* extra_state,
     PyObject* guarded_code,
     PyObject* backend) {
-  extra_state->cache_entry_list.emplace_front(guarded_code, backend);
-  auto new_iter = extra_state->cache_entry_list.begin();
+  std::list<CacheEntry>::iterator new_iter;
+  if (use_lru) {
+    extra_state->cache_entry_list.emplace_front(guarded_code, backend);
+    new_iter = extra_state->cache_entry_list.begin();
+  } else {
+    extra_state->cache_entry_list.emplace_back(guarded_code, backend);
+    new_iter = std::prev(extra_state->cache_entry_list.end());
+  }
   new_iter->_owner = extra_state;
   new_iter->_owner_loc = new_iter;
   // Set guard_manager references to extra_state and CacheEntry
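
For intuition, here is a toy Python model of the two policies above. `ToyCache`, `guards_match`, and `code` are hypothetical stand-ins for `ExtraState`, guard evaluation, and the cached code object; this is a sketch of the lookup/insertion behavior, not the real data structure:

```python
class ToyCache:
    """Models cache_entry_list under LRU (default) and insertion-order modes."""

    def __init__(self, use_lru: bool = True):
        self.use_lru = use_lru
        self.entries: list = []  # index 0 is checked first

    def lookup(self, args):
        for i, entry in enumerate(self.entries):
            if entry.guards_match(args):  # hypothetical guard check
                if self.use_lru:
                    # Mirrors move_to_front(found): hits migrate forward
                    self.entries.insert(0, self.entries.pop(i))
                return entry.code
        return None  # miss -> Dynamo would compile a new entry

    def add(self, entry):
        if self.use_lru:
            self.entries.insert(0, entry)  # mirrors emplace_front
        else:
            self.entries.append(entry)  # mirrors emplace_back
```

With `use_lru=False` the list preserves compilation order, which is what the new test relies on.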
@@ -269,6 +282,14 @@ void _load_precompile_entry(
   extra->precompile_entries.push_back(std::move(entry));
 }

+void _set_lru_cache(py::object boolean) {
+  if (py::cast<bool>(boolean)) {
+    use_lru = true;
+  } else {
+    use_lru = false;
+  }
+}
+
 py::list _debug_get_precompile_entries(const py::handle& code_obj) {
   if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) {
     throw py::type_error("expected a code object!");
torch/csrc/dynamo/extra_state.h

@@ -203,5 +203,6 @@ void _load_precompile_entry(
     py::object guard_manager,
     py::object dynamo_code);
 py::list _debug_get_precompile_entries(const py::handle& code_obj);
+void _set_lru_cache(py::object boolean);

 #endif

torch/csrc/dynamo/init.cpp

@@ -254,6 +254,7 @@ void initDynamoBindings(PyObject* torch) {
   m.def("_reset_precompile_entries", &_reset_precompile_entries);
   m.def("_load_precompile_entry", &_load_precompile_entry);
   m.def("_debug_get_precompile_entries", &_debug_get_precompile_entries);
+  m.def("_set_lru_cache", &_set_lru_cache);
   py::bind_vector<std::vector<uint8_t>>(m, "VectorUInt8");
   init_THPCaches();
   if (THP_PyOpcode_Caches != nullptr) {