[dynamo] Introduce _set_lru_cache (#167038)
Addresses the short-term plan for https://github.com/pytorch/pytorch/issues/166926. The toggle this PR introduces can't be on by default; that would be terrible for cache lookup times. A proper fix is in the works by @williamwen42.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167038
Approved by: https://github.com/williamwen42
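The escape hatch is meant to be flipped temporarily and restored afterwards, mirroring the try/finally pattern in the new test below. A minimal usage sketch (`_set_lru_cache` is the private binding added by this PR; everything else is standard torch):

    import torch

    # Short-term workaround from https://github.com/pytorch/pytorch/issues/166926:
    # disable Dynamo's LRU reordering so cache entries keep insertion order.
    torch._C._dynamo.eval_frame._set_lru_cache(False)
    try:
        ...  # run torch.compile'd code here
    finally:
        # Restore the default LRU behavior.
        torch._C._dynamo.eval_frame._set_lru_cache(True)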
commit 0b4dd08e04 (parent edd8d356b6)
committed by PyTorch MergeBot
@@ -48,6 +48,7 @@ from torch._dynamo.testing import (
     CompileCounter,
     CompileCounterWithBackend,
     EagerAndRecordGraphs,
+    expectedFailureDynamic,
     rand_strided,
     same,
     skipIfNotPy312,
@@ -7455,6 +7456,93 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
                 msg,
             )
 
+    @expectedFailureDynamic
+    def test_dynamo_default_lru_cache_behavior(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            return x + 10
+
+        torch._dynamo.reset()
+        assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+
+        # Step 1: Compile a static shapes graph
+        x = torch.randn(10, 10)
+        fn(x)
+        a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+        self.assertEqual(len(a), 1)
+        static_shapes_cache_entry = a[0]
+
+        # Step 2: Compile a dynamic shapes graph
+        y = torch.randn(20, 20)
+        fn(y)
+        b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+        self.assertEqual(len(b), 2)
+        self.assertEqual(b[1], static_shapes_cache_entry)
+        dynamic_shapes_cache_entry = b[0]
+
+        # Step 3: Run with Step 1's inputs
+        # LRU cache will match against dynamic shape graph first
+        fn(x)
+        c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+            fn._torchdynamo_orig_callable.__code__
+        )
+        self.assertEqual(len(c), 2)
+        self.assertEqual(c[0], dynamic_shapes_cache_entry)
+        self.assertEqual(c[1], static_shapes_cache_entry)
+
+    @expectedFailureDynamic
+    def test_dynamo_disable_lru_cache_behavior(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            return x + 10
+
+        def run():
+            torch._dynamo.reset()
+            assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+
+            # Step 1: Compile a static shapes graph
+            x = torch.randn(10, 10)
+            fn(x)
+            a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+            self.assertEqual(len(a), 1)
+            static_shapes_cache_entry = a[0]
+
+            # Step 2: Compile a dynamic shapes graph
+            y = torch.randn(20, 20)
+            fn(y)
+            b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+            self.assertEqual(len(b), 2)
+            self.assertEqual(b[0], static_shapes_cache_entry)
+            dynamic_shapes_cache_entry = b[1]
+
+            # Step 3: Run with Step 1's inputs
+            # LRU cache is disabled, we should still have static entry first
+            fn(x)
+            c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
+                fn._torchdynamo_orig_callable.__code__
+            )
+            self.assertEqual(len(c), 2)
+            self.assertEqual(c[0], static_shapes_cache_entry)
+            self.assertEqual(c[1], dynamic_shapes_cache_entry)
+
+        try:
+            torch._C._dynamo.eval_frame._set_lru_cache(False)
+            run()
+        finally:
+            torch._C._dynamo.eval_frame._set_lru_cache(True)
+
 
 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):
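Outside the test harness, the same private debug hook can be used to watch the ordering change. A minimal sketch built only from names in the diff above (`_debug_get_cache_entry_list`, `_torchdynamo_orig_callable`):

    import torch

    @torch.compile(backend="eager")
    def fn(x):
        return x + 10

    fn(torch.randn(10, 10))  # compiles a static-shapes entry
    fn(torch.randn(20, 20))  # recompiles with a dynamic-shapes entry

    entries = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
        fn._torchdynamo_orig_callable.__code__
    )
    # Under the default LRU policy, the newest (dynamic) entry sits first.
    print(len(entries), "cache entries; newest first under LRU")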
@@ -13,6 +13,11 @@
 #define _PyCode_SetExtra PyUnstable_Code_SetExtra
 #endif
 
+namespace {
+// Short-term fix for: https://github.com/pytorch/pytorch/issues/166926
+bool use_lru = true;
+} // namespace
+
 Py_ssize_t extra_index = -1;
 
 CacheEntry* ExtraState::get_first_entry() {
@@ -190,7 +195,9 @@ void lookup(
     ++index;
   }
   if (found) {
-    extra_state->move_to_front(found);
+    if (use_lru) {
+      extra_state->move_to_front(found);
+    }
     *maybe_cached_code = found->code.ptr();
     *trace_annotation = found->trace_annotation.c_str();
     return;
@@ -202,8 +209,14 @@ CacheEntry* create_cache_entry(
     ExtraState* extra_state,
     PyObject* guarded_code,
     PyObject* backend) {
-  extra_state->cache_entry_list.emplace_front(guarded_code, backend);
-  auto new_iter = extra_state->cache_entry_list.begin();
+  std::list<CacheEntry>::iterator new_iter;
+  if (use_lru) {
+    extra_state->cache_entry_list.emplace_front(guarded_code, backend);
+    new_iter = extra_state->cache_entry_list.begin();
+  } else {
+    extra_state->cache_entry_list.emplace_back(guarded_code, backend);
+    new_iter = std::prev(extra_state->cache_entry_list.end());
+  }
   new_iter->_owner = extra_state;
   new_iter->_owner_loc = new_iter;
   // Set guard_manager references to extra_state and CacheEntry
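Taken together with the lookup change above, the policy is: under LRU, new entries go to the front and hits are moved to the front; with LRU off, insertion order is preserved and hits don't reorder anything. A small Python model of just that ordering logic (a plain list stands in for the C++ CacheEntry list; the names and values are illustrative):

    entries = []
    use_lru = True  # mirrors the new file-local flag

    def insert(v):
        # create_cache_entry: emplace_front under LRU, emplace_back otherwise
        if use_lru:
            entries.insert(0, v)
        else:
            entries.append(v)

    def on_hit(v):
        # lookup: move_to_front only when LRU is enabled
        if use_lru:
            entries.remove(v)
            entries.insert(0, v)

    insert("static")
    insert("dynamic")
    assert entries == ["dynamic", "static"]  # newest entry is scanned first

    on_hit("static")
    assert entries == ["static", "dynamic"]  # hit moved to the front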
@@ -269,6 +282,14 @@ void _load_precompile_entry(
   extra->precompile_entries.push_back(std::move(entry));
 }
 
+void _set_lru_cache(py::object boolean) {
+  if (py::cast<bool>(boolean)) {
+    use_lru = true;
+  } else {
+    use_lru = false;
+  }
+}
+
 py::list _debug_get_precompile_entries(const py::handle& code_obj) {
   if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) {
     throw py::type_error("expected a code object!");
@@ -203,5 +203,6 @@ void _load_precompile_entry(
     py::object guard_manager,
     py::object dynamo_code);
 py::list _debug_get_precompile_entries(const py::handle& code_obj);
+void _set_lru_cache(py::object boolean);
 
 #endif
@@ -254,6 +254,7 @@ void initDynamoBindings(PyObject* torch) {
   m.def("_reset_precompile_entries", &_reset_precompile_entries);
   m.def("_load_precompile_entry", &_load_precompile_entry);
   m.def("_debug_get_precompile_entries", &_debug_get_precompile_entries);
+  m.def("_set_lru_cache", &_set_lru_cache);
   py::bind_vector<std::vector<uint8_t>>(m, "VectorUInt8");
   init_THPCaches();
   if (THP_PyOpcode_Caches != nullptr) {