""" Python implementation of function wrapping functionality for functorch.dim. """ from __future__ import annotations import functools from typing import Any, Optional, TYPE_CHECKING import torch from torch.utils._pytree import tree_map from ._dim_entry import DimEntry from ._enable_all_layers import EnableAllLayers from ._tensor_info import TensorInfo if TYPE_CHECKING: from collections.abc import Callable def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor: """Handle tensor conversion for torch function integration.""" return tensor class WrappedOperator: """ This class wraps PyTorch operations to support first-class dimensions. """ def __init__( self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim" ): self.orig = orig self.wrapper_implementation = wrapper_implementation self.name = getattr(orig, "__name__", "") self.doc = getattr(orig, "__doc__", None) self.dim_name = dim_name self.is_pointwise = False self.dim_offset = 0 self.keepdim_offset = 1 self.single_dim = False self.reduce = True # Update docstring if we have a dim_name if self.doc and self.dim_name: self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n" def function(self) -> Callable: """Create a wrapped function that calls our wrapper implementation.""" def wrapped_func(*args: Any, **kwargs: Any) -> Any: return self.wrapper_implementation(self, *args, **kwargs) # Copy metadata using functools.update_wrapper for just __name__ and __doc__ functools.update_wrapper( wrapped_func, self.orig, assigned=("__name__",), updated=() ) wrapped_func.__doc__ = self.doc return wrapped_func def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry: """Convert single dimension specification to DimEntry object.""" from . import Dim if isinstance(dim, Dim): if keepdim: raise ValueError("cannot preserve first-class dimensions with keepdim=True") return DimEntry(dim) elif isinstance(dim, int): i = dim while i >= 0: i -= ndim return DimEntry(i) else: return DimEntry() def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]: """Convert dimension specification to list of DimEntry objects.""" de = _wrap_dim(dim, ndim, keepdim) result = [] if not de.is_none(): result.append(de) else: for d in dim: result.append(_wrap_dim(d, ndim, keepdim)) return result def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any: """ This is the core method that handles dimension-aware operations. 
""" if not args: raise ValueError("Expected at least one argument (self)") # Get dimension argument dim_arg = kwargs.get(wrapper.dim_name) if dim_arg is None and wrapper.dim_offset < len(args): # Try to get dim from positional args (accounting for self at index 0) dim_idx = wrapper.dim_offset + 1 if dim_idx < len(args): dim_arg = args[dim_idx] # If no dimension argument provided, fall back to standard functorch handling if dim_arg is None: info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False) if not info: return wrapper.orig(*args, **kwargs) with EnableAllLayers(info.levels) as guard: assert info.batchedtensor is not None guard.inplace_update_layers(info.batchedtensor, info.levels) new_args = list(args) new_args[0] = handle_from_tensor(info.batchedtensor) result = wrapper.orig(*new_args, **kwargs) return guard.from_batched(result, info.has_device) # Handle dimension-aware operation info = TensorInfo.create(args[0]) if not info: return wrapper.orig(*args, **kwargs) # Check for keepdim parameter keepdim = False if wrapper.reduce: keepdim_arg = kwargs.get("keepdim") if keepdim_arg is None and wrapper.keepdim_offset < len(args): keepdim_idx = wrapper.keepdim_offset + 1 if keepdim_idx < len(args): keepdim_arg = args[keepdim_idx] if keepdim_arg is not None: keepdim = bool(keepdim_arg) # Wrap dimensions ndim = info.ndim() dims = _wrap_dims(dim_arg, ndim, keepdim) # Convert dimensions to indices and validate dim_indices: list[int] = [] seen = [False] * len(info.levels) for d in dims: midx = None for i, level in enumerate(info.levels): if level == d: midx = i break if midx is None: # Try to match by position/name more flexibly for i, level in enumerate(info.levels): if hasattr(level, "matches") and level.matches(d): midx = i break if midx is None: level_strs = [str(level) for level in info.levels] raise ValueError( f"Tensor with dimensions {level_strs} does not contain {d}" ) seen[midx] = True dim_indices.append(midx) # Determine new levels after reduction new_levels = [] if wrapper.reduce and not keepdim: for i, level in enumerate(info.levels): if not seen[i]: new_levels.append(level) else: new_levels = info.levels[:] # Create dimension indices for the original function if len(dim_indices) == 1: py_indices: Any = dim_indices[0] else: py_indices = tuple(dim_indices) # Update arguments new_args = list(args) new_kwargs = kwargs.copy() assert info.tensor is not None new_args[0] = handle_from_tensor(info.tensor) # Update dimension argument if wrapper.dim_name in new_kwargs: new_kwargs[wrapper.dim_name] = py_indices else: dim_idx = wrapper.dim_offset + 1 if dim_idx < len(new_args): new_args = list(new_args) new_args[dim_idx] = py_indices # Call original function result = wrapper.orig(*new_args, **new_kwargs) # Wrap results def wrap_result(obj: Any) -> Any: if isinstance(obj, torch.Tensor): from . import Tensor return Tensor.from_positional(obj, new_levels, info.has_device) return obj return tree_map(wrap_result, result) def _wrap( orig: Callable, dim_offset: Optional[int] = None, keepdim_offset: Optional[int] = None, dim_name: Optional[str] = None, single_dim: Optional[bool] = None, reduce: Optional[bool] = None, ) -> Callable: """ Wrap a PyTorch function to support first-class dimensions. 

def _wrap(
    orig: Callable,
    dim_offset: Optional[int] = None,
    keepdim_offset: Optional[int] = None,
    dim_name: Optional[str] = None,
    single_dim: Optional[bool] = None,
    reduce: Optional[bool] = None,
) -> Callable:
    """
    Wrap a PyTorch function to support first-class dimensions.

    Args:
        orig: Original function to wrap
        dim_offset: Offset of the dimension argument (default: 0)
        keepdim_offset: Offset of the keepdim argument (default: 1)
        dim_name: Name of the dimension parameter (default: "dim")
        single_dim: Whether the function takes a single dimension (default: False)
        reduce: Whether the function reduces dimensions (default: True)
    """
    dim_name = dim_name or "dim"
    wrapper = WrappedOperator(orig, patched_dim_method, dim_name)

    if dim_offset is not None:
        wrapper.dim_offset = dim_offset
    if keepdim_offset is not None:
        wrapper.keepdim_offset = keepdim_offset
    if single_dim is not None:
        wrapper.single_dim = single_dim
    if reduce is not None:
        wrapper.reduce = reduce

    return wrapper.function()


def call_torch_function(
    wrapper: WrappedOperator,
    func: Callable,
    types: tuple,
    args: tuple = (),
    kwargs: Optional[dict] = None,
) -> Any:
    """Handle __torch_function__ calls for wrapped operators."""
    if kwargs is None:
        kwargs = {}

    # Import here to avoid circular imports
    from . import _Tensor

    # Use the torch function mechanism from _Tensor
    return _Tensor.__torch_function__(func, types, args, kwargs)
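
# Usage sketch (illustrative; the choice of `torch.Tensor.sum` and
# `torch.nn.functional.softmax` as targets is an assumption here, the real
# registrations live elsewhere in functorch.dim):
#
#     wrapped_sum = _wrap(torch.Tensor.sum)                               # reducing op
#     wrapped_softmax = _wrap(torch.nn.functional.softmax, reduce=False)  # non-reducing op
#
# The returned callables keep the original __name__, carry the augmented __doc__,
# and route every call through `patched_dim_method`. `call_torch_function` is the
# companion handler for `__torch_function__` dispatch on wrapped operators; it
# simply forwards to `_Tensor.__torch_function__`.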