From 6919a909f3afba158b328b386bea7c7d3eb37ea5 Mon Sep 17 00:00:00 2001 From: Nicolas Macchioni Date: Fri, 19 Dec 2025 23:59:19 +0000 Subject: [PATCH] [inductor][caching] dump imc to disk in human readable format (#170236) Summary: tldr; added a method to dump Memoizer's memory to disk in JSON format at program exit in addition, we now include the encoded params in the cache entry (since the keys are hashes they don't provide much clues as to the origin of the entry by themself, hence why the params are now included) note: the testing for the new dump method is actually included in the subsequent diff Test Plan: ```buck test fbcode//mode/opt caffe2/test/inductor:caching``` Differential Revision: D88984165 Pull Request resolved: https://github.com/pytorch/pytorch/pull/170236 Approved by: https://github.com/aorenste --- test/dynamo/test_logging.py | 1 + test/inductor/test_caching.py | 272 ++++++++++++- torch/_inductor/runtime/caching/config.py | 11 + torch/_inductor/runtime/caching/exceptions.py | 4 +- torch/_inductor/runtime/caching/interfaces.py | 358 +++++++++++++++--- torch/_inductor/runtime/caching/locks.py | 9 +- torch/_logging/_internal.py | 5 + torch/_logging/_registrations.py | 5 + 8 files changed, 600 insertions(+), 65 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index f47311817eb..9e3953f3d6c 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -1065,6 +1065,7 @@ exclusions = { "compute_dependencies", "annotation", "node_runtime_estimation", + "caching", } for name in torch._logging._internal.log_registry.artifact_names: if name not in exclusions: diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index 84689432aea..7c85f0d04bb 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -2,6 +2,7 @@ # pyre-strict from __future__ import annotations +import json import os from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError, wait from contextlib import contextmanager @@ -1089,11 +1090,14 @@ class InterfacesTest(TestMixin, TestCase): self.assertEqual(result, 10) self.assertEqual(call_count, 1) - # Verify memory cache has the result - cache_key = interfaces._make_key(None, 5) + # Verify memory cache has the result as CacheEntry + cache_key = interfaces._BaseMemoizer._make_key(None, 5) memory_hit = persistent._memoizer._cache.get(cache_key) self.assertIsNotNone(memory_hit) - self.assertEqual(memory_hit.value, 10) + # Cache now stores CacheEntry with encoded_params and encoded_result + cache_entry = memory_hit.value + self.assertEqual(cache_entry.encoded_result, 10) + self.assertEqual(cache_entry.encoded_params, {"args": (5,), "kwargs": {}}) # Verify disk cache has the result (pickled) disk_hit = persistent._disk_cache.get(cache_key) @@ -1112,11 +1116,16 @@ class InterfacesTest(TestMixin, TestCase): # Setup: create a persistent memoizer and store only to disk persistent = interfaces.PersistentMemoizer(sub_dir=self.sub_dir()) - # Store a value directly to disk cache only - cache_key = interfaces._make_key(None, 5) + # Store a value directly to disk cache only (as CacheEntry) + cache_key = interfaces._BaseMemoizer._make_key(None, 5) import pickle - pickled_value = pickle.dumps(10) + # Cache now stores CacheEntry with encoded_params and encoded_result + cache_entry = interfaces.CacheEntry( + encoded_params={"args": (5,), "kwargs": {}}, + encoded_result=10, + ) + pickled_value = pickle.dumps(cache_entry) persistent._disk_cache.insert(cache_key, pickled_value) # Verify it's not in memory cache yet @@ -1137,7 +1146,9 @@ class InterfacesTest(TestMixin, TestCase): # Verify memory cache was populated from disk memory_hit = persistent._memoizer._cache.get(cache_key) self.assertIsNotNone(memory_hit) - self.assertEqual(memory_hit.value, 10) + # Memory cache should now contain the CacheEntry + cache_entry = memory_hit.value + self.assertEqual(cache_entry.encoded_result, 10) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) @@ -1170,7 +1181,7 @@ class InterfacesTest(TestMixin, TestCase): self.assertEqual(result2, 10) # Clear memory cache to simulate a new process - cache_key = interfaces._make_key(None, 5) + cache_key = interfaces._BaseMemoizer._make_key(None, 5) persistent._memoizer._cache = impls._InMemoryCacheImpl() # Third call - memory miss, disk hit, populates memory @@ -1210,7 +1221,7 @@ class InterfacesTest(TestMixin, TestCase): self.assertEqual(result2, 10) # Verify nothing was cached - cache_key = interfaces._make_key(None, 5) + cache_key = interfaces._BaseMemoizer._make_key(None, 5) memory_hit = persistent._memoizer._cache.get(cache_key) self.assertIsNone(memory_hit) disk_hit = persistent._disk_cache.get(cache_key) @@ -1263,6 +1274,249 @@ class InterfacesTest(TestMixin, TestCase): self.assertEqual(result1, 10) self.assertEqual(result2, 10) + # ============= Memoizer._dump_to_disk Tests ============= + + @patch_on_disk_cache_base_dir + @set_caching_module_enabled(True) + def test_memoizer_dump_to_disk_creates_json_file(self) -> None: + """Test that _dump_to_disk creates a JSON file with cached entries. + + Verifies that when _dump_to_disk is called, it creates a JSON file + containing the cached entries in a human-readable format. + """ + # Setup: create a memoizer and cache some values + memoizer = interfaces.Memoizer() + + @memoizer.record() + def compute(x: int) -> int: + return x * 2 + + compute(5) + compute(10) + + # Execute: dump the cache to disk + memoizer._dump_to_disk() + + # Assert: JSON file was created with correct structure + self.assertTrue(memoizer._shared_cache_filepath.exists()) + + with open(memoizer._shared_cache_filepath) as f: + data = json.load(f) + + self.assertIn("cache_entries", data) + self.assertIn("cache_size", data) + self.assertEqual(data["cache_size"], 2) + + # Verify entries have correct format + for entry in data["cache_entries"].values(): + self.assertIn("params", entry) + self.assertIn("result", entry) + + @patch_on_disk_cache_base_dir + @set_caching_module_enabled(True) + def test_memoizer_dump_to_disk_with_sub_key(self) -> None: + """Test that _dump_to_disk uses sub_key for nested structure. + + Verifies that when a Memoizer is initialized with a sub_key, + the cache entries are stored under cache_entries[sub_key]. + """ + # Setup: create a memoizer with sub_key and cache a value + sub_key = "test_sub_key" + memoizer = interfaces.Memoizer(sub_key=sub_key) + + @memoizer.record() + def compute(x: int) -> int: + return x * 2 + + compute(5) + + # Execute: dump the cache to disk + memoizer._dump_to_disk() + + # Assert: entries are stored under the sub_key + with open(memoizer._shared_cache_filepath) as f: + data = json.load(f) + + self.assertIn("cache_entries", data) + self.assertIn(sub_key, data["cache_entries"]) + + # The sub_key should contain the cache entries + sub_entries = data["cache_entries"][sub_key] + self.assertEqual(len(sub_entries), 1) + + # Verify entry format + for entry in sub_entries.values(): + self.assertIn("params", entry) + self.assertIn("result", entry) + self.assertEqual(entry["result"], 10) + + @patch_on_disk_cache_base_dir + @set_caching_module_enabled(True) + def test_memoizer_dump_to_disk_merges_with_existing(self) -> None: + """Test that _dump_to_disk merges with existing cache data. + + Verifies that when multiple Memoizer instances dump to the same file, + their entries are merged additively. + """ + # Setup: create first memoizer and cache a value + memoizer1 = interfaces.Memoizer() + + @memoizer1.record() + def compute1(x: int) -> int: + return x * 2 + + compute1(5) + memoizer1._dump_to_disk() + + # Create second memoizer and cache a different value + memoizer2 = interfaces.Memoizer() + + @memoizer2.record() + def compute2(x: int) -> int: + return x * 3 + + compute2(10) + + # Execute: dump second memoizer to disk + memoizer2._dump_to_disk() + + # Assert: both entries are in the file + with open(memoizer1._shared_cache_filepath) as f: + data = json.load(f) + + self.assertEqual(data["cache_size"], 2) + self.assertEqual(len(data["cache_entries"]), 2) + + @patch_on_disk_cache_base_dir + @set_caching_module_enabled(True) + def test_memoizer_dump_to_disk_skips_empty_cache(self) -> None: + """Test that _dump_to_disk does nothing when cache is empty. + + Verifies that when _dump_to_disk is called on an empty cache, + no file is created. + """ + # Setup: create a memoizer with no cached values + memoizer = interfaces.Memoizer() + + # Ensure the file doesn't exist beforehand + if memoizer._shared_cache_filepath.exists(): + memoizer._shared_cache_filepath.unlink() + + # Execute: dump the empty cache + memoizer._dump_to_disk() + + # Assert: no file was created + self.assertFalse(memoizer._shared_cache_filepath.exists()) + + @patch_on_disk_cache_base_dir + @set_caching_module_enabled(True) + def test_memoizer_dump_to_disk_handles_corrupt_file(self) -> None: + """Test that _dump_to_disk handles corrupt JSON files gracefully. + + Verifies that when the existing cache file contains invalid JSON, + _dump_to_disk starts fresh and overwrites the corrupt file. + """ + # Setup: create a memoizer and cache a value + memoizer = interfaces.Memoizer() + + @memoizer.record() + def compute(x: int) -> int: + return x * 2 + + compute(5) + + # Create a corrupt JSON file + memoizer._shared_cache_filepath.parent.mkdir(parents=True, exist_ok=True) + with open(memoizer._shared_cache_filepath, "w") as f: + f.write("{ invalid json content") + + # Execute: dump the cache (should handle corrupt file) + memoizer._dump_to_disk() + + # Assert: file now contains valid JSON with our entry + with open(memoizer._shared_cache_filepath) as f: + data = json.load(f) + + self.assertIn("cache_entries", data) + self.assertEqual(data["cache_size"], 1) + + @patch_on_disk_cache_base_dir + @set_caching_module_enabled(True) + def test_memoizer_dump_to_disk_stores_encoded_params_and_result(self) -> None: + """Test that _dump_to_disk stores both encoded params and result. + + Verifies that the dumped JSON contains both the encoded parameters + and the encoded result for each cache entry, making it useful for debugging. + """ + # Setup: create a memoizer with custom encoder and cache a value + memoizer = interfaces.Memoizer() + + @memoizer.record() + def compute(x: int, y: int) -> int: + return x + y + + compute(5, 10) + + # Execute: dump the cache to disk + memoizer._dump_to_disk() + + # Assert: entry contains both params and result + with open(memoizer._shared_cache_filepath) as f: + data = json.load(f) + + # Get the single entry + entries = data["cache_entries"] + self.assertEqual(len(entries), 1) + + entry = next(iter(entries.values())) + self.assertEqual(entry["result"], 15) + self.assertEqual(entry["params"]["args"], [5, 10]) + self.assertEqual(entry["params"]["kwargs"], {}) + + @patch_on_disk_cache_base_dir + @set_caching_module_enabled(True) + def test_memoizer_dump_to_disk_multiple_sub_keys(self) -> None: + """Test that multiple Memoizers with different sub_keys coexist. + + Verifies that Memoizers with different sub_keys store their entries + under separate namespaces in the same JSON file. + """ + # Setup: create two memoizers with different sub_keys + memoizer1 = interfaces.Memoizer(sub_key="feature_a") + memoizer2 = interfaces.Memoizer(sub_key="feature_b") + + @memoizer1.record() + def compute_a(x: int) -> int: + return x * 2 + + @memoizer2.record() + def compute_b(x: int) -> int: + return x * 3 + + compute_a(5) + compute_b(10) + + # Execute: dump both caches + memoizer1._dump_to_disk() + memoizer2._dump_to_disk() + + # Assert: both sub_keys exist with their respective entries + with open(memoizer1._shared_cache_filepath) as f: + data = json.load(f) + + self.assertIn("feature_a", data["cache_entries"]) + self.assertIn("feature_b", data["cache_entries"]) + + # Verify each sub_key has one entry with correct result + feature_a_entries = data["cache_entries"]["feature_a"] + feature_b_entries = data["cache_entries"]["feature_b"] + + self.assertEqual(len(feature_a_entries), 1) + self.assertEqual(len(feature_b_entries), 1) + + self.assertEqual(next(iter(feature_a_entries.values()))["result"], 10) + self.assertEqual(next(iter(feature_b_entries.values()))["result"], 30) + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/runtime/caching/config.py b/torch/_inductor/runtime/caching/config.py index d36a77e5b74..594c6bdd7b1 100644 --- a/torch/_inductor/runtime/caching/config.py +++ b/torch/_inductor/runtime/caching/config.py @@ -70,3 +70,14 @@ IS_CACHING_MODULE_ENABLED: Callable[[], bool] = partial( _CACHING_MODULE_OSS_DEFAULT, _CACHING_MODULE_ENV_VAR_OVERRIDE, ) + + +# Controls whether the Memoizer dumps its cache to a JSON file on destruction. +# This is useful for debugging and inspection purposes. +_DUMP_MEMOIZER_CACHE_ENV_VAR: str = "TORCHINDUCTOR_DUMP_MEMOIZER_CACHE" +_DUMP_MEMOIZER_CACHE_DEFAULT: bool = False +IS_DUMP_MEMOIZER_CACHE_ENABLED: Callable[[], bool] = partial( + _env_var_config, + _DUMP_MEMOIZER_CACHE_ENV_VAR, + _DUMP_MEMOIZER_CACHE_DEFAULT, +) diff --git a/torch/_inductor/runtime/caching/exceptions.py b/torch/_inductor/runtime/caching/exceptions.py index 096554c1996..8d49b46035d 100644 --- a/torch/_inductor/runtime/caching/exceptions.py +++ b/torch/_inductor/runtime/caching/exceptions.py @@ -9,7 +9,7 @@ for user-facing errors that also inherit from TypeError for compatibility. from threading import Lock -from filelock import FileLock +from filelock import BaseFileLock class CacheError(Exception): @@ -53,7 +53,7 @@ class FileLockTimeoutError(SystemError): limit, indicating that the lock could not be acquired within the allotted time. """ - def __init__(self, flock: FileLock, timeout: float) -> None: + def __init__(self, flock: BaseFileLock, timeout: float) -> None: """Initialize the file lock timeout error with detailed lock information. Args: diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 14800006c21..7949748cef5 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -6,16 +6,28 @@ This module provides high-level caching interfaces for memoization and result caching functionality. """ +import atexit import functools +import json +import logging import pickle from collections.abc import Callable +from dataclasses import dataclass from hashlib import sha256 from os import PathLike -from typing import cast +from pathlib import Path +from typing import cast, TypedDict from typing_extensions import ParamSpec, TypeVar -from . import config, implementations +from filelock import FileLock +import torch +from torch._inductor.runtime.runtime_utils import cache_dir + +from . import config, implementations, locks + + +logger = torch._logging.getArtifactLogger(__name__, "caching") # Type variable for function parameters _P = ParamSpec("_P") @@ -25,36 +37,51 @@ _R = TypeVar("_R") _EncodedR = TypeVar("_EncodedR") -def _make_key( - custom_params_encoder: Callable[..., object] | None, - *args: object, - **kwargs: object, -) -> str: - """Generate a cache key from function parameters. +class CacheDumpEntry(TypedDict): + """A single cache entry in the dump format. - Args: - custom_params_encoder: Optional encoder to apply to function parameters. - If None, params are pickled directly. - *args: Positional arguments to encode. - **kwargs: Keyword arguments to encode. - - Returns: - A 32-character hex string suitable for use as a cache key. + Attributes: + params: The encoded function parameters. + result: The encoded function result. """ - if custom_params_encoder is None: - # Pickle the parameters directly - pickled_params: bytes = pickle.dumps((args, kwargs)) - else: - # Encode the parameters using the custom encoder - encoded_params = custom_params_encoder(*args, **kwargs) - # Pickle the encoded output - pickled_params = pickle.dumps(encoded_params) - # Hash the pickled bytes with SHA256 - hash_obj = sha256(pickled_params) + params: object + result: object - # Get hex digest and truncate to 32 characters - return hash_obj.hexdigest()[:32] + +class CacheDump(TypedDict): + """The structure of the memoizer cache dump file. + + The cache_entries field contains either: + - Direct entries: {cache_key: CacheDumpEntry} when no sub_key is used + - Nested entries: {sub_key: {cache_key: CacheDumpEntry}} when sub_key is set + + Multiple Memoizer instances with different sub_keys can coexist in the same file. + + Attributes: + cache_entries: Dictionary mapping cache keys (or sub_keys) to cache entries. + cache_size: The total number of cache entries. When sub_keys are used, + this is the sum of entries across all sub_keys. + """ + + cache_entries: dict[str, CacheDumpEntry | dict[str, CacheDumpEntry]] + cache_size: int + + +@dataclass +class CacheEntry: + """A cache entry containing encoded parameters and result. + + This dataclass stores the encoded form of function parameters and + the (possibly encoded) result for human-readable cache dumps. + + Attributes: + encoded_params: The encoded function parameters used for debugging/inspection. + encoded_result: The (possibly encoded) result of the function call. + """ + + encoded_params: object + encoded_result: object class _BaseMemoizer: @@ -64,6 +91,38 @@ class _BaseMemoizer: record and replay functionality. """ + @staticmethod + def _make_key( + custom_params_encoder: Callable[..., object] | None, + *args: object, + **kwargs: object, + ) -> str: + """Generate a cache key from function parameters. + + Args: + custom_params_encoder: Optional encoder to apply to function parameters. + If None, params are pickled directly. + *args: Positional arguments to encode. + **kwargs: Keyword arguments to encode. + + Returns: + A 32-character hex string suitable for use as a cache key. + """ + if custom_params_encoder is None: + # Pickle the parameters directly + pickled_params: bytes = pickle.dumps((args, kwargs)) + else: + # Encode the parameters using the custom encoder + encoded_params = custom_params_encoder(*args, **kwargs) + # Pickle the encoded output + pickled_params = pickle.dumps(encoded_params) + + # Hash the pickled bytes with SHA256 + hash_obj = sha256(pickled_params) + + # Get hex digest and truncate to 32 characters + return hash_obj.hexdigest()[:32] + def record( self, custom_params_encoder: Callable[_P, object] | None = None, @@ -178,16 +237,190 @@ class Memoizer(_BaseMemoizer): This class provides methods for recording, retrieving, and managing cached function results in memory with custom encoding/decoding logic. + Cache entries are stored as CacheEntry objects containing both the encoded + parameters and the encoded result. This makes debugging easier since entries + can be inspected with full context about what inputs produced each result. + + On memoizer destruction, the cache is automatically dumped to a shared JSON file + in the cache directory for debugging and inspection purposes. Multiple Memoizer + instances contribute to the same file additively. + Note: Use this over functools.cache when you need to support parameters that functools.cache cannot handle, or when you need custom encoding/decoding of results. """ - def __init__(self) -> None: - """Initialize the Memoizer instance with an in-memory cache.""" + def __init__(self, sub_key: str | None = None) -> None: + """Initialize the Memoizer instance with an in-memory cache. + + Args: + sub_key: Optional key for organizing cache entries in the JSON dump. + If provided, cache entries are stored under cache_entries[sub_key]. + If None, cache entries are merged directly into root cache_entries. + """ self._cache: implementations._InMemoryCacheImpl = ( implementations._InMemoryCacheImpl() ) + # Optional sub_key for nested cache structure + self._sub_key: str | None = sub_key + # Register atexit handler to dump cache on program exit + if config.IS_DUMP_MEMOIZER_CACHE_ENABLED(): + atexit.register(self._dump_to_disk) + + @functools.cached_property + def _shared_cache_filepath(self) -> Path: + """Get the shared cache filepath for memoizer cache dumps. + + Returns: + The path to the shared memoizer cache JSON file. + """ + return Path(cache_dir()) / "memoizer_cache.json" + + @functools.cached_property + def _shared_cache_lockfile(self) -> Path: + """Get the lock file path for the shared memoizer cache. + + Returns: + The path to the lock file for the shared cache. + """ + return Path(cache_dir()) / "memoizer_cache.lock" + + def _read_dump_from_disk(self) -> CacheDump | None: + """Read the cache dump from disk. + + Attempts to read and parse the shared cache JSON file. + + Returns: + The cache dump if the file exists and is valid JSON, None otherwise. + """ + try: + with open(self._shared_cache_filepath) as f: + data = json.load(f) + return cast(CacheDump, data) + except FileNotFoundError: + return None + except json.JSONDecodeError: + return None + + def _write_dump_to_disk(self, dump: CacheDump) -> None: + """Write the cache dump to disk. + + Writes the provided dump to the shared cache JSON file and logs the result. + + Args: + dump: The cache dump to write. + """ + try: + with open(self._shared_cache_filepath, "w") as f: + json.dump(dump, f, indent=2) + + # Log the filepath + if self._sub_key: + logger.log( + logging.INFO, + "Memoizer cache (sub_key=%s) dumped to: %s", + self._sub_key, + self._shared_cache_filepath, + ) + else: + logger.log( + logging.INFO, + "Memoizer cache dumped to: %s", + self._shared_cache_filepath, + ) + except Exception as e: + # If dumping fails, just log it and don't crash the program + logger.log( + logging.WARNING, + "Warning: Failed to dump memoizer cache: %s", + e, + ) + + def _prepare_dump(self, existing_dump: CacheDump | None) -> CacheDump: + """Prepare a cache dump from the current Memoizer state. + + Takes the existing dump (if any) and merges it with the current + in-memory cache entries. + + Args: + existing_dump: The existing dump to merge with, or None if starting fresh. + + Returns: + The prepared cache dump ready to be written to disk. + """ + # Start with existing data or empty structure + if existing_dump is not None: + dump = existing_dump + else: + dump: CacheDump = {"cache_entries": {}, "cache_size": 0} + + # Ensure cache_entries exists + if "cache_entries" not in dump: + dump["cache_entries"] = {} + + # Format cache entries as {"params": ..., "result": ...} + formatted_cache: dict[str, CacheDumpEntry] = {} + for key, value in self._cache._memory.items(): + entry = cast(CacheEntry, value) + formatted_cache[key] = CacheDumpEntry( + params=entry.encoded_params, + result=entry.encoded_result, + ) + + # Merge based on sub_key + if self._sub_key: + # Store under sub_key + dump["cache_entries"][self._sub_key] = formatted_cache + else: + # Merge directly into cache_entries + dump["cache_entries"].update(formatted_cache) + + # Calculate total cache size across all entries + total_size = 0 + for value in dump["cache_entries"].values(): + if isinstance(value, dict): + # Check if it's a CacheDumpEntry (has 'params' and 'result') or a sub_key dict + if "params" in value and "result" in value: + # Direct entry + total_size += 1 + else: + # Sub_key with nested entries + total_size += len(value) + else: + total_size += 1 + dump["cache_size"] = total_size + + return dump + + def _dump_to_disk(self) -> None: + """Dump the in-memory cache to a shared JSON file. + + This method is automatically called on program exit via atexit. + It reads any existing cache data, merges it with this instance's cache, + and writes the combined result back. Multiple Memoizer instances + contribute to the same file additively. + + If self._sub_key is set and non-empty, cache entries are stored under + cache_entries[sub_key]. Otherwise, they're merged directly into cache_entries. + + Cache entries are formatted as {"params": , "result": } + for better human readability. + + The filepath where the cache was dumped is logged. + """ + # Skip if cache is empty + if not self._cache._memory: + return + + # Ensure parent directory exists + self._shared_cache_filepath.parent.mkdir(parents=True, exist_ok=True) + + # Acquire file lock to ensure thread/process safety + flock = FileLock(str(self._shared_cache_lockfile)) + with locks._acquire_flock_with_timeout(flock): + existing_dump = self._read_dump_from_disk() + dump = self._prepare_dump(existing_dump) + self._write_dump_to_disk(dump) def record( self, @@ -245,7 +478,16 @@ class Memoizer(_BaseMemoizer): result = fn(*args, **kwargs) # Generate cache key from parameters - cache_key = _make_key(custom_params_encoder, *args, **kwargs) + cache_key = self._make_key(custom_params_encoder, *args, **kwargs) + + # Encode params for human-readable dump + if custom_params_encoder is not None: + encoded_params = custom_params_encoder(*args, **kwargs) + else: + encoded_params = { + "args": args, + "kwargs": kwargs, + } # Encode the result if encoder is provided if custom_result_encoder is not None: @@ -255,8 +497,12 @@ class Memoizer(_BaseMemoizer): else: encoded_result = result - # Store in cache - self._cache.insert(cache_key, encoded_result) + # Store CacheEntry in cache + cache_entry = CacheEntry( + encoded_params=encoded_params, + encoded_result=encoded_result, + ) + self._cache.insert(cache_key, cache_entry) # Return the original result (not the encoded version) return result @@ -325,20 +571,22 @@ class Memoizer(_BaseMemoizer): KeyError: If no cached result exists for the given parameters. """ # Generate cache key from parameters - cache_key = _make_key(custom_params_encoder, *args, **kwargs) + cache_key = self._make_key(custom_params_encoder, *args, **kwargs) # Check if result is cached cached_hit = self._cache.get(cache_key) if cached_hit is None: raise KeyError(f"No cached result found for key: {cache_key}") + # Extract the cached value + cache_entry = cast(CacheEntry, cached_hit.value) + # Decode and return the cached result - cached_value = cached_hit.value if custom_result_decoder is not None: # Get the decoder function by calling the factory with params decoder_fn = custom_result_decoder(*args, **kwargs) - return decoder_fn(cast(_EncodedR, cached_value)) - return cast(_R, cached_value) + return decoder_fn(cast(_EncodedR, cache_entry.encoded_result)) + return cast(_R, cache_entry.encoded_result) return inner @@ -355,6 +603,11 @@ class PersistentMemoizer(_BaseMemoizer): Results are persisted across process restarts. + On program exit, the in-memory cache entries are automatically dumped to + the shared JSON file. If sub_dir is non-empty, entries are stored under + a nested structure based on the sub_dir. If sub_dir is empty, entries are + merged directly into the root cache_entries. + Note: Use this over functools.cache when you need to support parameters that functools.cache cannot handle, custom result encoding and/or decoding, or when you need disk caching to persist results across program boundaries. @@ -366,12 +619,14 @@ class PersistentMemoizer(_BaseMemoizer): Args: sub_dir: Optional subdirectory within the cache directory for organizing cached results. Defaults to empty string if not specified. + If non-empty, cache entries will be stored under cache_entries[sub_dir]. + If empty, cache entries are merged into root cache_entries. """ # Use a Memoizer instance for in-memory caching - self._memoizer: Memoizer = Memoizer() + self._memoizer: Memoizer = Memoizer(sub_key=str(sub_dir) if sub_dir else None) # Store on-disk cache as a separate attribute self._disk_cache: implementations._OnDiskCacheImpl = ( - implementations._OnDiskCacheImpl(sub_dir) + implementations._OnDiskCacheImpl(sub_dir=sub_dir) ) def record( @@ -436,16 +691,17 @@ class PersistentMemoizer(_BaseMemoizer): result = memory_record_fn(*args, **kwargs) # Also store in disk cache - cache_key = _make_key(custom_params_encoder, *args, **kwargs) + cache_key = self._make_key(custom_params_encoder, *args, **kwargs) - # Get the encoded result from memory cache + # Get the cache entry from memory cache # We know it must be there since memory_record_fn just cached it cached_hit = self._memoizer._cache.get(cache_key) - encoded_result = cached_hit.value # type: ignore[union-attr] + assert cached_hit, "Cache entry must exist in memory cache" + cache_entry = cast(CacheEntry, cached_hit.value) - # Store in disk cache (requires bytes, so pickle) - pickled_result: bytes = pickle.dumps(encoded_result) - self._disk_cache.insert(cache_key, pickled_result) + # Store the full CacheEntry in disk cache for easier debugging + pickled_entry: bytes = pickle.dumps(cache_entry) + self._disk_cache.insert(cache_key, pickled_entry) return result @@ -529,21 +785,21 @@ class PersistentMemoizer(_BaseMemoizer): pass # Memory miss, check disk # Memory miss - check disk cache - cache_key = _make_key(custom_params_encoder, *args, **kwargs) + cache_key = self._make_key(custom_params_encoder, *args, **kwargs) disk_hit = self._disk_cache.get(cache_key) if disk_hit is not None: - # Disk cache hit - unpickle the bytes + # Disk cache hit - unpickle the CacheEntry pickled_value = disk_hit.value - cached_value = pickle.loads(pickled_value) + cache_entry = cast(CacheEntry, pickle.loads(pickled_value)) # Populate memory cache for future access - self._memoizer._cache.insert(cache_key, cached_value) + self._memoizer._cache.insert(cache_key, cache_entry) # Decode and return if custom_result_decoder is not None: decoder_fn = custom_result_decoder(*args, **kwargs) - return decoder_fn(cast(_EncodedR, cached_value)) - return cast(_R, cached_value) + return decoder_fn(cast(_EncodedR, cache_entry.encoded_result)) + return cast(_R, cache_entry.encoded_result) # Complete miss raise KeyError(f"No cached result found for key: {cache_key}") diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index 6cdab37afbc..1d276780a20 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -15,7 +15,7 @@ from contextlib import _GeneratorContextManager, contextmanager, ExitStack from typing import TYPE_CHECKING, TypeAlias from typing_extensions import Protocol -from filelock import FileLock, Timeout +from filelock import BaseFileLock, Timeout from . import exceptions @@ -119,7 +119,7 @@ def _unsafe_acquire_lock_with_timeout(lock: Lock, timeout: float | None = None) @contextmanager def _acquire_flock_with_timeout( - flock: FileLock, + flock: BaseFileLock, timeout: float | None = None, ) -> Generator[None, None, None]: """Context manager that safely acquires a FileLock with timeout and automatically releases it. @@ -155,7 +155,10 @@ def _acquire_flock_with_timeout( flock.release() -def _unsafe_acquire_flock_with_timeout(flock: FileLock, timeout: float | None) -> None: +def _unsafe_acquire_flock_with_timeout( + flock: BaseFileLock, + timeout: float | None, +) -> None: """Acquire a FileLock with timeout without automatic release (unsafe). This function acquires a file lock with timeout support but does NOT automatically diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index c2ad4924995..ac8c080ac7d 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -256,6 +256,7 @@ def set_logs( inductor_metrics: bool = False, hierarchical_compile: bool = False, compute_dependencies: bool = False, + caching: bool = False, ) -> None: """ Sets the log level for individual components and toggles individual log @@ -456,6 +457,9 @@ def set_logs( hierarchical_compile (:class:`bool`): Whether to emit debug info for hierarchical compilation. Default: ``False`` + caching (:class:`bool`): + Whether to emit detailed Inductor caching information. Default: ``False`` + Example:: >>> # xdoctest: +SKIP @@ -570,6 +574,7 @@ def set_logs( inductor_metrics=inductor_metrics, hierarchical_compile=hierarchical_compile, compute_dependencies=compute_dependencies, + caching=caching, ) diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index f0077f0f9bb..a64d09aec84 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -257,3 +257,8 @@ register_artifact( off_by_default=True, ) register_artifact("custom_format_test_artifact", "Testing only", log_format="") +register_artifact( + "caching", + "Detailed Inductor caching information.", + off_by_default=True, +)