torchdim Python port (#160236)
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new dispatch algorithm is that we first land in `__torch_function__`, check whether a special FCD implementation needs to be dispatched to first, and if there is none we fall back to the standard level strategy.
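A minimal sketch of the shape of this dispatch (the registry `FCD_SPECIAL_IMPLS` and the helper `run_level_strategy` are illustrative placeholders, not the actual implementation):

```python
import torch

FCD_SPECIAL_IMPLS: dict = {}  # hypothetical registry of special FCD implementations


class _Tensor:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        special = FCD_SPECIAL_IMPLS.get(func)
        if special is not None:
            # A special first-class-dim implementation wins.
            return special(*args, **kwargs)
        # Otherwise fall back to the standard level strategy: unwrap to batched
        # tensors, run func, and rewrap the result.
        return run_level_strategy(func, types, args, kwargs)  # placeholder helper
```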
Because there are no longer C-binding equivalents of the classes, we've condensed _C.Dim and Dim together, and similarly for Tensor. This resulted in some bugs, as the Python API is sometimes different from the C API. I've attempted to disambiguate these, but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful, as Dim must abide by Tensor equality semantics in Python but uses pointer equality in C (DimEntry doesn't have this problem). Another subtle C/Python difference is that we no longer get implicit conversions from Dim to DimEntry, which also caused some bugs.
Much of the mechanical porting work was done by Claude Code. I have a separate PR that deletes functorch._C, but it was useful to have dim.cpp to point Claude at, so I haven't done that in this PR. From a reviewing perspective, I still need to re-check that I didn't forget to port anything; one noticeably missing "small" thing is patched_dim_method. I am still in the middle of a careful side-by-side review of the ports; "simplifications" from Claude Code were also a major source of bugs.
There are two major feature gaps in the implementation:
- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy; it just needs to be done. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets the list as advanced indexing and hands it to torch.tensor to be turned into a tensor, so it never becomes eligible for `__torch_function__`. I think I might need to hard-code a special case for this or something? A rough repro sketch follows below.
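Roughly the failing pattern (illustrative only; the exact size-binding API is an assumption, not a verified repro):

```python
import torch
from functorch.dim import dims

x, y = dims(2)
x.size = 2              # assuming sizes can be bound up front like this
t = torch.rand(6, 4)

# Intended: split dim 0 (size 6) into x (size 2) and y (size 3).
# Today the list is routed through torch.tensor([x, y]) by advanced indexing
# before __torch_function__ ever sees the Dims, so this path fails.
t2 = t[[x, y]]
```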
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
"""
|
|
|
|
|
Python implementation of function wrapping functionality for functorch.dim.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import functools
|
2025-10-17 00:09:49 +00:00
|
|
|
from typing import Any, Optional, TYPE_CHECKING
|
import torch
from torch.utils._pytree import tree_map

from ._dim_entry import DimEntry
from ._enable_all_layers import EnableAllLayers
from ._tensor_info import TensorInfo


if TYPE_CHECKING:
    from collections.abc import Callable
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Handle tensor conversion for torch function integration."""
    return tensor


class WrappedOperator:
    """
    This class wraps PyTorch operations to support first-class dimensions.
    """

    def __init__(
        self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
    ):
        self.orig = orig
        self.wrapper_implementation = wrapper_implementation
        self.name = getattr(orig, "__name__", "")
        self.doc = getattr(orig, "__doc__", None)
        self.dim_name = dim_name

        self.is_pointwise = False
        self.dim_offset = 0
        self.keepdim_offset = 1
        self.single_dim = False
        self.reduce = True

        # Update docstring if we have a dim_name
        if self.doc and self.dim_name:
            self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"

    def function(self) -> Callable:
        """Create a wrapped function that calls our wrapper implementation."""

        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            return self.wrapper_implementation(self, *args, **kwargs)

        # Copy metadata using functools.update_wrapper for just __name__ and __doc__
        functools.update_wrapper(
            wrapped_func, self.orig, assigned=("__name__",), updated=()
        )
        wrapped_func.__doc__ = self.doc

        return wrapped_func
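# For example (hypothetical; real registrations happen via _wrap below):
#
#   op = WrappedOperator(torch.Tensor.sum, patched_dim_method)
#   wrapped_sum = op.function()
#   # wrapped_sum(t, 0) now runs patched_dim_method(op, t, 0)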
def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
    """Convert single dimension specification to DimEntry object."""
    from . import Dim

    if isinstance(dim, Dim):
        if keepdim:
            raise ValueError("cannot preserve first-class dimensions with keepdim=True")
        return DimEntry(dim)
    elif isinstance(dim, int):
        i = dim
        while i >= 0:
            i -= ndim
        return DimEntry(i)
    else:
        return DimEntry()


def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
    """Convert dimension specification to list of DimEntry objects."""
    de = _wrap_dim(dim, ndim, keepdim)
    result = []
    if not de.is_none():
        result.append(de)
    else:
        for d in dim:
            result.append(_wrap_dim(d, ndim, keepdim))
    return result
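# Illustrative behaviour of the helpers above (comments only; `batch`, `d0`,
# and `d1` are hypothetical first-class Dims):
#
#   _wrap_dim(1, ndim=3)        -> DimEntry(-2)     # ints normalized to negative positions
#   _wrap_dim(batch, ndim=3)    -> DimEntry(batch)
#   _wrap_dim([d0, d1], ndim=3) -> DimEntry()       # "none" entry signals a sequence
#
#   _wrap_dims([d0, 1], ndim=3) -> [DimEntry(d0), DimEntry(-2)]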
def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
    """
    This is the core method that handles dimension-aware operations.
    """
    if not args:
        raise ValueError("Expected at least one argument (self)")

    # Get dimension argument
    dim_arg = kwargs.get(wrapper.dim_name)
    if dim_arg is None and wrapper.dim_offset < len(args):
        # Try to get dim from positional args (accounting for self at index 0)
        dim_idx = wrapper.dim_offset + 1
        if dim_idx < len(args):
            dim_arg = args[dim_idx]

    # If no dimension argument provided, fall back to standard functorch handling
    if dim_arg is None:
        info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
        if not info:
            return wrapper.orig(*args, **kwargs)

        with EnableAllLayers(info.levels) as guard:
            assert info.batchedtensor is not None
            guard.inplace_update_layers(info.batchedtensor, info.levels)
            new_args = list(args)
            new_args[0] = handle_from_tensor(info.batchedtensor)
            result = wrapper.orig(*new_args, **kwargs)
            return guard.from_batched(result, info.has_device)

    # Handle dimension-aware operation
    info = TensorInfo.create(args[0])
    if not info:
        return wrapper.orig(*args, **kwargs)

    # Check for keepdim parameter
    keepdim = False
    if wrapper.reduce:
        keepdim_arg = kwargs.get("keepdim")
        if keepdim_arg is None and wrapper.keepdim_offset < len(args):
            keepdim_idx = wrapper.keepdim_offset + 1
            if keepdim_idx < len(args):
                keepdim_arg = args[keepdim_idx]
        if keepdim_arg is not None:
            keepdim = bool(keepdim_arg)

    # Wrap dimensions
    ndim = info.ndim()
    dims = _wrap_dims(dim_arg, ndim, keepdim)

    # Convert dimensions to indices and validate
    dim_indices: list[int] = []
    seen = [False] * len(info.levels)

    for d in dims:
        midx = None
        for i, level in enumerate(info.levels):
            if level == d:
                midx = i
                break

        if midx is None:
            # Try to match by position/name more flexibly
            for i, level in enumerate(info.levels):
                if hasattr(level, "matches") and level.matches(d):
                    midx = i
                    break

        if midx is None:
            level_strs = [str(level) for level in info.levels]
            raise ValueError(
                f"Tensor with dimensions {level_strs} does not contain {d}"
            )

        seen[midx] = True
        dim_indices.append(midx)

    # Determine new levels after reduction
    new_levels = []
    if wrapper.reduce and not keepdim:
        for i, level in enumerate(info.levels):
            if not seen[i]:
                new_levels.append(level)
    else:
        new_levels = info.levels[:]

    # Create dimension indices for the original function
    if len(dim_indices) == 1:
        py_indices: Any = dim_indices[0]
    else:
        py_indices = tuple(dim_indices)

    # Update arguments
    new_args = list(args)
    new_kwargs = kwargs.copy()
    assert info.tensor is not None
    new_args[0] = handle_from_tensor(info.tensor)

    # Update dimension argument
    if wrapper.dim_name in new_kwargs:
        new_kwargs[wrapper.dim_name] = py_indices
    else:
        dim_idx = wrapper.dim_offset + 1
        if dim_idx < len(new_args):
            new_args = list(new_args)
            new_args[dim_idx] = py_indices

    # Call original function
    result = wrapper.orig(*new_args, **new_kwargs)

    # Wrap results
    def wrap_result(obj: Any) -> Any:
        if isinstance(obj, torch.Tensor):
            from . import Tensor

            return Tensor.from_positional(obj, new_levels, info.has_device)
        return obj

    return tree_map(wrap_result, result)
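# Illustrative walk-through (hypothetical, assuming `sum` was wrapped via _wrap
# below and `batch` is a first-class Dim bound in `t`):
#
#   t.sum(batch)
#     -> patched_dim_method(wrapper, t, batch)
#     -> `batch` is located in info.levels and translated to a positional index
#     -> wrapper.orig (the plain sum) runs on the positional tensor
#     -> the result is rewrapped with the remaining levels via Tensor.from_positional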
def _wrap(
    orig: Callable,
    dim_offset: Optional[int] = None,
    keepdim_offset: Optional[int] = None,
    dim_name: Optional[str] = None,
    single_dim: Optional[bool] = None,
    reduce: Optional[bool] = None,
) -> Callable:
    """
    Wrap a PyTorch function to support first-class dimensions.

    Args:
        orig: Original function to wrap
        dim_offset: Offset for dimension argument (default: 0)
        keepdim_offset: Offset for keepdim argument (default: 1)
        dim_name: Name of dimension parameter (default: "dim")
        single_dim: Whether function takes single dimension (default: False)
        reduce: Whether function reduces dimensions (default: True)
    """
    dim_name = dim_name or "dim"

    wrapper = WrappedOperator(orig, patched_dim_method, dim_name)

    if dim_offset is not None:
        wrapper.dim_offset = dim_offset
    if keepdim_offset is not None:
        wrapper.keepdim_offset = keepdim_offset
    if single_dim is not None:
        wrapper.single_dim = single_dim
    if reduce is not None:
        wrapper.reduce = reduce

    return wrapper.function()
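# Hypothetical registrations (the real call sites live elsewhere in functorch.dim):
#
#   sum_with_dims = _wrap(torch.Tensor.sum)
#   softmax_with_dims = _wrap(torch.Tensor.softmax, reduce=False)
#
# The returned callables accept either integers or Dim objects for `dim` and
# route through patched_dim_method above.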
def call_torch_function(
    wrapper: WrappedOperator,
    func: Callable,
    types: tuple,
    args: tuple = (),
    kwargs: Optional[dict] = None,
) -> Any:
    """
    Handle __torch_function__ calls for wrapped operators.
    """
    if kwargs is None:
        kwargs = {}

    # Import here to avoid circular imports
    from . import _Tensor

    # Use the torch function mechanism from _Tensor
    return _Tensor.__torch_function__(func, types, args, kwargs)
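# Sketch of the expected call pattern (an assumption about the call sites, which
# live elsewhere in functorch.dim): a wrapped operator's __torch_function__ hook
# delegates here, and this in turn reuses _Tensor's generic handling, e.g.
#
#   call_torch_function(op, torch.Tensor.sum, (torch.Tensor,), (t, 0))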