[inductor][caching] dump imc to disk in human readable format (#170236)

Summary:
tl;dr: added a method that dumps the Memoizer's in-memory cache to disk in JSON format at program exit

in addition, we now include the encoded params in the cache entry: since the keys are hashes, they don't provide many clues about an entry's origin by themselves

note: testing for the new dump method is included in the subsequent diff
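
For reference, the dumped file is a single JSON object. A hypothetical memoizer_cache.json (hash keys shortened; shapes taken from the CacheDump structure and the tests below) would look roughly like:

```
{
  "cache_entries": {
    "3f2a9c...": {"params": {"args": [5], "kwargs": {}}, "result": 10},
    "feature_a": {
      "7b1d04...": {"params": {"args": [10], "kwargs": {}}, "result": 30}
    }
  },
  "cache_size": 2
}
```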

Test Plan: ```buck test fbcode//mode/opt caffe2/test/inductor:caching```

Differential Revision: D88984165

Pull Request resolved: https://github.com/pytorch/pytorch/pull/170236
Approved by: https://github.com/aorenste
Authored by: Nicolas Macchioni
Date: 2025-12-19 23:59:19 +00:00
Committed by: PyTorch MergeBot
parent bdd3df645b
commit 6919a909f3
8 changed files with 600 additions and 65 deletions

View File

@@ -1065,6 +1065,7 @@ exclusions = {
"compute_dependencies",
"annotation",
"node_runtime_estimation",
"caching",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:

View File

@@ -2,6 +2,7 @@
# pyre-strict
from __future__ import annotations
import json
import os
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError, wait
from contextlib import contextmanager
@@ -1089,11 +1090,14 @@ class InterfacesTest(TestMixin, TestCase):
self.assertEqual(result, 10)
self.assertEqual(call_count, 1)
# Verify memory cache has the result
cache_key = interfaces._make_key(None, 5)
# Verify memory cache has the result as CacheEntry
cache_key = interfaces._BaseMemoizer._make_key(None, 5)
memory_hit = persistent._memoizer._cache.get(cache_key)
self.assertIsNotNone(memory_hit)
self.assertEqual(memory_hit.value, 10)
# Cache now stores CacheEntry with encoded_params and encoded_result
cache_entry = memory_hit.value
self.assertEqual(cache_entry.encoded_result, 10)
self.assertEqual(cache_entry.encoded_params, {"args": (5,), "kwargs": {}})
# Verify disk cache has the result (pickled)
disk_hit = persistent._disk_cache.get(cache_key)
@@ -1112,11 +1116,16 @@ class InterfacesTest(TestMixin, TestCase):
# Setup: create a persistent memoizer and store only to disk
persistent = interfaces.PersistentMemoizer(sub_dir=self.sub_dir())
# Store a value directly to disk cache only
cache_key = interfaces._make_key(None, 5)
# Store a value directly to disk cache only (as CacheEntry)
cache_key = interfaces._BaseMemoizer._make_key(None, 5)
import pickle
pickled_value = pickle.dumps(10)
# Cache now stores CacheEntry with encoded_params and encoded_result
cache_entry = interfaces.CacheEntry(
encoded_params={"args": (5,), "kwargs": {}},
encoded_result=10,
)
pickled_value = pickle.dumps(cache_entry)
persistent._disk_cache.insert(cache_key, pickled_value)
# Verify it's not in memory cache yet
@@ -1137,7 +1146,9 @@ class InterfacesTest(TestMixin, TestCase):
# Verify memory cache was populated from disk
memory_hit = persistent._memoizer._cache.get(cache_key)
self.assertIsNotNone(memory_hit)
self.assertEqual(memory_hit.value, 10)
# Memory cache should now contain the CacheEntry
cache_entry = memory_hit.value
self.assertEqual(cache_entry.encoded_result, 10)
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
@@ -1170,7 +1181,7 @@ class InterfacesTest(TestMixin, TestCase):
self.assertEqual(result2, 10)
# Clear memory cache to simulate a new process
cache_key = interfaces._make_key(None, 5)
cache_key = interfaces._BaseMemoizer._make_key(None, 5)
persistent._memoizer._cache = impls._InMemoryCacheImpl()
# Third call - memory miss, disk hit, populates memory
@@ -1210,7 +1221,7 @@ class InterfacesTest(TestMixin, TestCase):
self.assertEqual(result2, 10)
# Verify nothing was cached
cache_key = interfaces._make_key(None, 5)
cache_key = interfaces._BaseMemoizer._make_key(None, 5)
memory_hit = persistent._memoizer._cache.get(cache_key)
self.assertIsNone(memory_hit)
disk_hit = persistent._disk_cache.get(cache_key)
@@ -1263,6 +1274,249 @@ class InterfacesTest(TestMixin, TestCase):
self.assertEqual(result1, 10)
self.assertEqual(result2, 10)
# ============= Memoizer._dump_to_disk Tests =============
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
def test_memoizer_dump_to_disk_creates_json_file(self) -> None:
"""Test that _dump_to_disk creates a JSON file with cached entries.
Verifies that when _dump_to_disk is called, it creates a JSON file
containing the cached entries in a human-readable format.
"""
# Setup: create a memoizer and cache some values
memoizer = interfaces.Memoizer()
@memoizer.record()
def compute(x: int) -> int:
return x * 2
compute(5)
compute(10)
# Execute: dump the cache to disk
memoizer._dump_to_disk()
# Assert: JSON file was created with correct structure
self.assertTrue(memoizer._shared_cache_filepath.exists())
with open(memoizer._shared_cache_filepath) as f:
data = json.load(f)
self.assertIn("cache_entries", data)
self.assertIn("cache_size", data)
self.assertEqual(data["cache_size"], 2)
# Verify entries have correct format
for entry in data["cache_entries"].values():
self.assertIn("params", entry)
self.assertIn("result", entry)
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
def test_memoizer_dump_to_disk_with_sub_key(self) -> None:
"""Test that _dump_to_disk uses sub_key for nested structure.
Verifies that when a Memoizer is initialized with a sub_key,
the cache entries are stored under cache_entries[sub_key].
"""
# Setup: create a memoizer with sub_key and cache a value
sub_key = "test_sub_key"
memoizer = interfaces.Memoizer(sub_key=sub_key)
@memoizer.record()
def compute(x: int) -> int:
return x * 2
compute(5)
# Execute: dump the cache to disk
memoizer._dump_to_disk()
# Assert: entries are stored under the sub_key
with open(memoizer._shared_cache_filepath) as f:
data = json.load(f)
self.assertIn("cache_entries", data)
self.assertIn(sub_key, data["cache_entries"])
# The sub_key should contain the cache entries
sub_entries = data["cache_entries"][sub_key]
self.assertEqual(len(sub_entries), 1)
# Verify entry format
for entry in sub_entries.values():
self.assertIn("params", entry)
self.assertIn("result", entry)
self.assertEqual(entry["result"], 10)
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
def test_memoizer_dump_to_disk_merges_with_existing(self) -> None:
"""Test that _dump_to_disk merges with existing cache data.
Verifies that when multiple Memoizer instances dump to the same file,
their entries are merged additively.
"""
# Setup: create first memoizer and cache a value
memoizer1 = interfaces.Memoizer()
@memoizer1.record()
def compute1(x: int) -> int:
return x * 2
compute1(5)
memoizer1._dump_to_disk()
# Create second memoizer and cache a different value
memoizer2 = interfaces.Memoizer()
@memoizer2.record()
def compute2(x: int) -> int:
return x * 3
compute2(10)
# Execute: dump second memoizer to disk
memoizer2._dump_to_disk()
# Assert: both entries are in the file
with open(memoizer1._shared_cache_filepath) as f:
data = json.load(f)
self.assertEqual(data["cache_size"], 2)
self.assertEqual(len(data["cache_entries"]), 2)
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
def test_memoizer_dump_to_disk_skips_empty_cache(self) -> None:
"""Test that _dump_to_disk does nothing when cache is empty.
Verifies that when _dump_to_disk is called on an empty cache,
no file is created.
"""
# Setup: create a memoizer with no cached values
memoizer = interfaces.Memoizer()
# Ensure the file doesn't exist beforehand
if memoizer._shared_cache_filepath.exists():
memoizer._shared_cache_filepath.unlink()
# Execute: dump the empty cache
memoizer._dump_to_disk()
# Assert: no file was created
self.assertFalse(memoizer._shared_cache_filepath.exists())
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
def test_memoizer_dump_to_disk_handles_corrupt_file(self) -> None:
"""Test that _dump_to_disk handles corrupt JSON files gracefully.
Verifies that when the existing cache file contains invalid JSON,
_dump_to_disk starts fresh and overwrites the corrupt file.
"""
# Setup: create a memoizer and cache a value
memoizer = interfaces.Memoizer()
@memoizer.record()
def compute(x: int) -> int:
return x * 2
compute(5)
# Create a corrupt JSON file
memoizer._shared_cache_filepath.parent.mkdir(parents=True, exist_ok=True)
with open(memoizer._shared_cache_filepath, "w") as f:
f.write("{ invalid json content")
# Execute: dump the cache (should handle corrupt file)
memoizer._dump_to_disk()
# Assert: file now contains valid JSON with our entry
with open(memoizer._shared_cache_filepath) as f:
data = json.load(f)
self.assertIn("cache_entries", data)
self.assertEqual(data["cache_size"], 1)
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
def test_memoizer_dump_to_disk_stores_encoded_params_and_result(self) -> None:
"""Test that _dump_to_disk stores both encoded params and result.
Verifies that the dumped JSON contains both the encoded parameters
and the encoded result for each cache entry, making it useful for debugging.
"""
# Setup: create a memoizer and cache a value
memoizer = interfaces.Memoizer()
@memoizer.record()
def compute(x: int, y: int) -> int:
return x + y
compute(5, 10)
# Execute: dump the cache to disk
memoizer._dump_to_disk()
# Assert: entry contains both params and result
with open(memoizer._shared_cache_filepath) as f:
data = json.load(f)
# Get the single entry
entries = data["cache_entries"]
self.assertEqual(len(entries), 1)
entry = next(iter(entries.values()))
self.assertEqual(entry["result"], 15)
self.assertEqual(entry["params"]["args"], [5, 10])
self.assertEqual(entry["params"]["kwargs"], {})
@patch_on_disk_cache_base_dir
@set_caching_module_enabled(True)
def test_memoizer_dump_to_disk_multiple_sub_keys(self) -> None:
"""Test that multiple Memoizers with different sub_keys coexist.
Verifies that Memoizers with different sub_keys store their entries
under separate namespaces in the same JSON file.
"""
# Setup: create two memoizers with different sub_keys
memoizer1 = interfaces.Memoizer(sub_key="feature_a")
memoizer2 = interfaces.Memoizer(sub_key="feature_b")
@memoizer1.record()
def compute_a(x: int) -> int:
return x * 2
@memoizer2.record()
def compute_b(x: int) -> int:
return x * 3
compute_a(5)
compute_b(10)
# Execute: dump both caches
memoizer1._dump_to_disk()
memoizer2._dump_to_disk()
# Assert: both sub_keys exist with their respective entries
with open(memoizer1._shared_cache_filepath) as f:
data = json.load(f)
self.assertIn("feature_a", data["cache_entries"])
self.assertIn("feature_b", data["cache_entries"])
# Verify each sub_key has one entry with correct result
feature_a_entries = data["cache_entries"]["feature_a"]
feature_b_entries = data["cache_entries"]["feature_b"]
self.assertEqual(len(feature_a_entries), 1)
self.assertEqual(len(feature_b_entries), 1)
self.assertEqual(next(iter(feature_a_entries.values()))["result"], 10)
self.assertEqual(next(iter(feature_b_entries.values()))["result"], 30)
if __name__ == "__main__":
run_tests()

View File

@@ -70,3 +70,14 @@ IS_CACHING_MODULE_ENABLED: Callable[[], bool] = partial(
_CACHING_MODULE_OSS_DEFAULT,
_CACHING_MODULE_ENV_VAR_OVERRIDE,
)
# Controls whether the Memoizer dumps its cache to a JSON file on destruction.
# This is useful for debugging and inspection purposes.
_DUMP_MEMOIZER_CACHE_ENV_VAR: str = "TORCHINDUCTOR_DUMP_MEMOIZER_CACHE"
_DUMP_MEMOIZER_CACHE_DEFAULT: bool = False
IS_DUMP_MEMOIZER_CACHE_ENABLED: Callable[[], bool] = partial(
_env_var_config,
_DUMP_MEMOIZER_CACHE_ENV_VAR,
_DUMP_MEMOIZER_CACHE_DEFAULT,
)
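
As a usage sketch (the script name is hypothetical; per the interfaces diff below, the flag is read via IS_DUMP_MEMOIZER_CACHE_ENABLED() when a Memoizer is constructed):

```
# Opt in to the JSON dump for a single run, e.g. from a shell:
#   TORCHINDUCTOR_DUMP_MEMOIZER_CACHE=1 python train.py
# or programmatically, before any Memoizer is constructed:
import os

os.environ["TORCHINDUCTOR_DUMP_MEMOIZER_CACHE"] = "1"
```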

View File

@@ -9,7 +9,7 @@ for user-facing errors that also inherit from TypeError for compatibility.
from threading import Lock
from filelock import FileLock
from filelock import BaseFileLock
class CacheError(Exception):
@@ -53,7 +53,7 @@ class FileLockTimeoutError(SystemError):
limit, indicating that the lock could not be acquired within the allotted time.
"""
def __init__(self, flock: FileLock, timeout: float) -> None:
def __init__(self, flock: BaseFileLock, timeout: float) -> None:
"""Initialize the file lock timeout error with detailed lock information.
Args:

View File

@@ -6,16 +6,28 @@ This module provides high-level caching interfaces for memoization and
result caching functionality.
"""
import atexit
import functools
import json
import logging
import pickle
from collections.abc import Callable
from dataclasses import dataclass
from hashlib import sha256
from os import PathLike
from typing import cast
from pathlib import Path
from typing import cast, TypedDict
from typing_extensions import ParamSpec, TypeVar
from . import config, implementations
from filelock import FileLock
import torch
from torch._inductor.runtime.runtime_utils import cache_dir
from . import config, implementations, locks
logger = torch._logging.getArtifactLogger(__name__, "caching")
# Type variable for function parameters
_P = ParamSpec("_P")
@@ -25,36 +37,51 @@ _R = TypeVar("_R")
_EncodedR = TypeVar("_EncodedR")
def _make_key(
custom_params_encoder: Callable[..., object] | None,
*args: object,
**kwargs: object,
) -> str:
"""Generate a cache key from function parameters.
Args:
custom_params_encoder: Optional encoder to apply to function parameters.
If None, params are pickled directly.
*args: Positional arguments to encode.
**kwargs: Keyword arguments to encode.
Returns:
A 32-character hex string suitable for use as a cache key.
"""
if custom_params_encoder is None:
# Pickle the parameters directly
pickled_params: bytes = pickle.dumps((args, kwargs))
else:
# Encode the parameters using the custom encoder
encoded_params = custom_params_encoder(*args, **kwargs)
# Pickle the encoded output
pickled_params = pickle.dumps(encoded_params)
# Hash the pickled bytes with SHA256
hash_obj = sha256(pickled_params)
# Get hex digest and truncate to 32 characters
return hash_obj.hexdigest()[:32]
class CacheDumpEntry(TypedDict):
"""A single cache entry in the dump format.
Attributes:
params: The encoded function parameters.
result: The encoded function result.
"""
params: object
result: object
class CacheDump(TypedDict):
"""The structure of the memoizer cache dump file.
The cache_entries field contains either:
- Direct entries: {cache_key: CacheDumpEntry} when no sub_key is used
- Nested entries: {sub_key: {cache_key: CacheDumpEntry}} when sub_key is set
Multiple Memoizer instances with different sub_keys can coexist in the same file.
Attributes:
cache_entries: Dictionary mapping cache keys (or sub_keys) to cache entries.
cache_size: The total number of cache entries. When sub_keys are used,
this is the sum of entries across all sub_keys.
"""
cache_entries: dict[str, CacheDumpEntry | dict[str, CacheDumpEntry]]
cache_size: int
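
For illustration, the two shapes a CacheDump can take (hash keys shortened, values hypothetical):

```
# Without a sub_key: entries sit directly under cache_entries.
flat_dump: CacheDump = {
    "cache_entries": {
        "3f2a9c...": {"params": {"args": [5], "kwargs": {}}, "result": 10},
    },
    "cache_size": 1,
}

# With sub_key="feature_a": entries are namespaced under that key.
nested_dump: CacheDump = {
    "cache_entries": {
        "feature_a": {
            "3f2a9c...": {"params": {"args": [5], "kwargs": {}}, "result": 10},
        },
    },
    "cache_size": 1,
}
```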
@dataclass
class CacheEntry:
"""A cache entry containing encoded parameters and result.
This dataclass stores the encoded form of function parameters and
the (possibly encoded) result for human-readable cache dumps.
Attributes:
encoded_params: The encoded function parameters used for debugging/inspection.
encoded_result: The (possibly encoded) result of the function call.
"""
encoded_params: object
encoded_result: object
class _BaseMemoizer:
@@ -64,6 +91,38 @@ class _BaseMemoizer:
record and replay functionality.
"""
@staticmethod
def _make_key(
custom_params_encoder: Callable[..., object] | None,
*args: object,
**kwargs: object,
) -> str:
"""Generate a cache key from function parameters.
Args:
custom_params_encoder: Optional encoder to apply to function parameters.
If None, params are pickled directly.
*args: Positional arguments to encode.
**kwargs: Keyword arguments to encode.
Returns:
A 32-character hex string suitable for use as a cache key.
"""
if custom_params_encoder is None:
# Pickle the parameters directly
pickled_params: bytes = pickle.dumps((args, kwargs))
else:
# Encode the parameters using the custom encoder
encoded_params = custom_params_encoder(*args, **kwargs)
# Pickle the encoded output
pickled_params = pickle.dumps(encoded_params)
# Hash the pickled bytes with SHA256
hash_obj = sha256(pickled_params)
# Get hex digest and truncate to 32 characters
return hash_obj.hexdigest()[:32]
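
A quick sketch of the resulting behavior, mirroring the _make_key calls in the tests above:

```
# Identical parameters always hash to the same truncated SHA256 key.
key = _BaseMemoizer._make_key(None, 5)
assert key == _BaseMemoizer._make_key(None, 5)
assert len(key) == 32  # hex digest truncated to 32 characters
```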
def record(
self,
custom_params_encoder: Callable[_P, object] | None = None,
@@ -178,16 +237,190 @@ class Memoizer(_BaseMemoizer):
This class provides methods for recording, retrieving, and managing
cached function results in memory with custom encoding/decoding logic.
Cache entries are stored as CacheEntry objects containing both the encoded
parameters and the encoded result. This makes debugging easier since entries
can be inspected with full context about what inputs produced each result.
On memoizer destruction, the cache is automatically dumped to a shared JSON file
in the cache directory for debugging and inspection purposes. Multiple Memoizer
instances contribute to the same file additively.
Note: Use this over functools.cache when you need to support parameters
that functools.cache cannot handle, or when you need custom encoding/decoding
of results.
"""
def __init__(self) -> None:
"""Initialize the Memoizer instance with an in-memory cache."""
def __init__(self, sub_key: str | None = None) -> None:
"""Initialize the Memoizer instance with an in-memory cache.
Args:
sub_key: Optional key for organizing cache entries in the JSON dump.
If provided, cache entries are stored under cache_entries[sub_key].
If None, cache entries are merged directly into root cache_entries.
"""
self._cache: implementations._InMemoryCacheImpl = (
implementations._InMemoryCacheImpl()
)
# Optional sub_key for nested cache structure
self._sub_key: str | None = sub_key
# Register atexit handler to dump cache on program exit
if config.IS_DUMP_MEMOIZER_CACHE_ENABLED():
atexit.register(self._dump_to_disk)
@functools.cached_property
def _shared_cache_filepath(self) -> Path:
"""Get the shared cache filepath for memoizer cache dumps.
Returns:
The path to the shared memoizer cache JSON file.
"""
return Path(cache_dir()) / "memoizer_cache.json"
@functools.cached_property
def _shared_cache_lockfile(self) -> Path:
"""Get the lock file path for the shared memoizer cache.
Returns:
The path to the lock file for the shared cache.
"""
return Path(cache_dir()) / "memoizer_cache.lock"
def _read_dump_from_disk(self) -> CacheDump | None:
"""Read the cache dump from disk.
Attempts to read and parse the shared cache JSON file.
Returns:
The cache dump if the file exists and is valid JSON, None otherwise.
"""
try:
with open(self._shared_cache_filepath) as f:
data = json.load(f)
return cast(CacheDump, data)
except FileNotFoundError:
return None
except json.JSONDecodeError:
return None
def _write_dump_to_disk(self, dump: CacheDump) -> None:
"""Write the cache dump to disk.
Writes the provided dump to the shared cache JSON file and logs the result.
Args:
dump: The cache dump to write.
"""
try:
with open(self._shared_cache_filepath, "w") as f:
json.dump(dump, f, indent=2)
# Log the filepath
if self._sub_key:
logger.log(
logging.INFO,
"Memoizer cache (sub_key=%s) dumped to: %s",
self._sub_key,
self._shared_cache_filepath,
)
else:
logger.log(
logging.INFO,
"Memoizer cache dumped to: %s",
self._shared_cache_filepath,
)
except Exception as e:
# If dumping fails, just log it and don't crash the program
logger.log(
logging.WARNING,
"Warning: Failed to dump memoizer cache: %s",
e,
)
def _prepare_dump(self, existing_dump: CacheDump | None) -> CacheDump:
"""Prepare a cache dump from the current Memoizer state.
Takes the existing dump (if any) and merges it with the current
in-memory cache entries.
Args:
existing_dump: The existing dump to merge with, or None if starting fresh.
Returns:
The prepared cache dump ready to be written to disk.
"""
# Start with existing data or empty structure
if existing_dump is not None:
dump = existing_dump
else:
dump: CacheDump = {"cache_entries": {}, "cache_size": 0}
# Ensure cache_entries exists
if "cache_entries" not in dump:
dump["cache_entries"] = {}
# Format cache entries as {"params": ..., "result": ...}
formatted_cache: dict[str, CacheDumpEntry] = {}
for key, value in self._cache._memory.items():
entry = cast(CacheEntry, value)
formatted_cache[key] = CacheDumpEntry(
params=entry.encoded_params,
result=entry.encoded_result,
)
# Merge based on sub_key
if self._sub_key:
# Store under sub_key
dump["cache_entries"][self._sub_key] = formatted_cache
else:
# Merge directly into cache_entries
dump["cache_entries"].update(formatted_cache)
# Calculate total cache size across all entries
total_size = 0
for value in dump["cache_entries"].values():
if isinstance(value, dict):
# Check if it's a CacheDumpEntry (has 'params' and 'result') or a sub_key dict
if "params" in value and "result" in value:
# Direct entry
total_size += 1
else:
# Sub_key with nested entries
total_size += len(value)
else:
total_size += 1
dump["cache_size"] = total_size
return dump
def _dump_to_disk(self) -> None:
"""Dump the in-memory cache to a shared JSON file.
This method is automatically called on program exit via atexit.
It reads any existing cache data, merges it with this instance's cache,
and writes the combined result back. Multiple Memoizer instances
contribute to the same file additively.
If self._sub_key is set and non-empty, cache entries are stored under
cache_entries[sub_key]. Otherwise, they're merged directly into cache_entries.
Cache entries are formatted as {"params": <encoded_params>, "result": <encoded_result>}
for better human readability.
The filepath where the cache was dumped is logged.
"""
# Skip if cache is empty
if not self._cache._memory:
return
# Ensure parent directory exists
self._shared_cache_filepath.parent.mkdir(parents=True, exist_ok=True)
# Acquire file lock to ensure thread/process safety
flock = FileLock(str(self._shared_cache_lockfile))
with locks._acquire_flock_with_timeout(flock):
existing_dump = self._read_dump_from_disk()
dump = self._prepare_dump(existing_dump)
self._write_dump_to_disk(dump)
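
A minimal end-to-end sketch of the dump flow, mirroring the tests above (the sub_key "demo" is a hypothetical name):

```
memoizer = interfaces.Memoizer(sub_key="demo")

@memoizer.record()
def double(x: int) -> int:
    return x * 2

double(21)
# Normally triggered via atexit when TORCHINDUCTOR_DUMP_MEMOIZER_CACHE=1;
# calling it directly forces the merge-and-write immediately.
memoizer._dump_to_disk()
```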
def record(
self,
@@ -245,7 +478,16 @@ class Memoizer(_BaseMemoizer):
result = fn(*args, **kwargs)
# Generate cache key from parameters
cache_key = _make_key(custom_params_encoder, *args, **kwargs)
cache_key = self._make_key(custom_params_encoder, *args, **kwargs)
# Encode params for human-readable dump
if custom_params_encoder is not None:
encoded_params = custom_params_encoder(*args, **kwargs)
else:
encoded_params = {
"args": args,
"kwargs": kwargs,
}
# Encode the result if encoder is provided
if custom_result_encoder is not None:
@@ -255,8 +497,12 @@ class Memoizer(_BaseMemoizer):
else:
encoded_result = result
# Store in cache
self._cache.insert(cache_key, encoded_result)
# Store CacheEntry in cache
cache_entry = CacheEntry(
encoded_params=encoded_params,
encoded_result=encoded_result,
)
self._cache.insert(cache_key, cache_entry)
# Return the original result (not the encoded version)
return result
@@ -325,20 +571,22 @@ class Memoizer(_BaseMemoizer):
KeyError: If no cached result exists for the given parameters.
"""
# Generate cache key from parameters
cache_key = _make_key(custom_params_encoder, *args, **kwargs)
cache_key = self._make_key(custom_params_encoder, *args, **kwargs)
# Check if result is cached
cached_hit = self._cache.get(cache_key)
if cached_hit is None:
raise KeyError(f"No cached result found for key: {cache_key}")
# Extract the cached value
cache_entry = cast(CacheEntry, cached_hit.value)
# Decode and return the cached result
cached_value = cached_hit.value
if custom_result_decoder is not None:
# Get the decoder function by calling the factory with params
decoder_fn = custom_result_decoder(*args, **kwargs)
return decoder_fn(cast(_EncodedR, cached_value))
return cast(_R, cached_value)
return decoder_fn(cast(_EncodedR, cache_entry.encoded_result))
return cast(_R, cache_entry.encoded_result)
return inner
@@ -355,6 +603,11 @@ class PersistentMemoizer(_BaseMemoizer):
Results are persisted across process restarts.
On program exit, the in-memory cache entries are automatically dumped to
the shared JSON file. If sub_dir is non-empty, entries are stored under
a nested structure based on the sub_dir. If sub_dir is empty, entries are
merged directly into the root cache_entries.
Note: Use this over functools.cache when you need to support parameters
that functools.cache cannot handle, custom result encoding and/or decoding,
or when you need disk caching to persist results across program boundaries.
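
A sketch of the sub_dir-to-sub_key forwarding described above ("my_feature" is a hypothetical name):

```
# Entries cached through this instance are dumped under
# cache_entries["my_feature"] in memoizer_cache.json.
persistent = interfaces.PersistentMemoizer(sub_dir="my_feature")
```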
@@ -366,12 +619,14 @@ class PersistentMemoizer(_BaseMemoizer):
Args:
sub_dir: Optional subdirectory within the cache directory for
organizing cached results. Defaults to empty string if not specified.
If non-empty, cache entries will be stored under cache_entries[sub_dir].
If empty, cache entries are merged into root cache_entries.
"""
# Use a Memoizer instance for in-memory caching
self._memoizer: Memoizer = Memoizer()
self._memoizer: Memoizer = Memoizer(sub_key=str(sub_dir) if sub_dir else None)
# Store on-disk cache as a separate attribute
self._disk_cache: implementations._OnDiskCacheImpl = (
implementations._OnDiskCacheImpl(sub_dir)
implementations._OnDiskCacheImpl(sub_dir=sub_dir)
)
def record(
@@ -436,16 +691,17 @@ class PersistentMemoizer(_BaseMemoizer):
result = memory_record_fn(*args, **kwargs)
# Also store in disk cache
cache_key = _make_key(custom_params_encoder, *args, **kwargs)
cache_key = self._make_key(custom_params_encoder, *args, **kwargs)
# Get the encoded result from memory cache
# Get the cache entry from memory cache
# We know it must be there since memory_record_fn just cached it
cached_hit = self._memoizer._cache.get(cache_key)
encoded_result = cached_hit.value # type: ignore[union-attr]
assert cached_hit, "Cache entry must exist in memory cache"
cache_entry = cast(CacheEntry, cached_hit.value)
# Store in disk cache (requires bytes, so pickle)
pickled_result: bytes = pickle.dumps(encoded_result)
self._disk_cache.insert(cache_key, pickled_result)
# Store the full CacheEntry in disk cache for easier debugging
pickled_entry: bytes = pickle.dumps(cache_entry)
self._disk_cache.insert(cache_key, pickled_entry)
return result
@@ -529,21 +785,21 @@ class PersistentMemoizer(_BaseMemoizer):
pass # Memory miss, check disk
# Memory miss - check disk cache
cache_key = _make_key(custom_params_encoder, *args, **kwargs)
cache_key = self._make_key(custom_params_encoder, *args, **kwargs)
disk_hit = self._disk_cache.get(cache_key)
if disk_hit is not None:
# Disk cache hit - unpickle the bytes
# Disk cache hit - unpickle the CacheEntry
pickled_value = disk_hit.value
cached_value = pickle.loads(pickled_value)
cache_entry = cast(CacheEntry, pickle.loads(pickled_value))
# Populate memory cache for future access
self._memoizer._cache.insert(cache_key, cached_value)
self._memoizer._cache.insert(cache_key, cache_entry)
# Decode and return
if custom_result_decoder is not None:
decoder_fn = custom_result_decoder(*args, **kwargs)
return decoder_fn(cast(_EncodedR, cached_value))
return cast(_R, cached_value)
return decoder_fn(cast(_EncodedR, cache_entry.encoded_result))
return cast(_R, cache_entry.encoded_result)
# Complete miss
raise KeyError(f"No cached result found for key: {cache_key}")

View File

@@ -15,7 +15,7 @@ from contextlib import _GeneratorContextManager, contextmanager, ExitStack
from typing import TYPE_CHECKING, TypeAlias
from typing_extensions import Protocol
from filelock import FileLock, Timeout
from filelock import BaseFileLock, Timeout
from . import exceptions
@@ -119,7 +119,7 @@ def _unsafe_acquire_lock_with_timeout(lock: Lock, timeout: float | None = None)
@contextmanager
def _acquire_flock_with_timeout(
flock: FileLock,
flock: BaseFileLock,
timeout: float | None = None,
) -> Generator[None, None, None]:
"""Context manager that safely acquires a FileLock with timeout and automatically releases it.
@@ -155,7 +155,10 @@ def _acquire_flock_with_timeout(
flock.release()
def _unsafe_acquire_flock_with_timeout(flock: FileLock, timeout: float | None) -> None:
def _unsafe_acquire_flock_with_timeout(
flock: BaseFileLock,
timeout: float | None,
) -> None:
"""Acquire a FileLock with timeout without automatic release (unsafe).
This function acquires a file lock with timeout support but does NOT automatically

View File

@@ -256,6 +256,7 @@ def set_logs(
inductor_metrics: bool = False,
hierarchical_compile: bool = False,
compute_dependencies: bool = False,
caching: bool = False,
) -> None:
"""
Sets the log level for individual components and toggles individual log
@@ -456,6 +457,9 @@ def set_logs(
hierarchical_compile (:class:`bool`):
Whether to emit debug info for hierarchical compilation. Default: ``False``
caching (:class:`bool`):
Whether to emit detailed Inductor caching information. Default: ``False``
Example::
>>> # xdoctest: +SKIP
@@ -570,6 +574,7 @@ def set_logs(
inductor_metrics=inductor_metrics,
hierarchical_compile=hierarchical_compile,
compute_dependencies=compute_dependencies,
caching=caching,
)

View File

@@ -257,3 +257,8 @@ register_artifact(
off_by_default=True,
)
register_artifact("custom_format_test_artifact", "Testing only", log_format="")
register_artifact(
"caching",
"Detailed Inductor caching information.",
off_by_default=True,
)