"""
Python implementation of the function-wrapping machinery for functorch.dim (first-class dimensions).
"""
from __future__ import annotations
import functools
from typing import Any, Optional, TYPE_CHECKING
import torch
from torch.utils._pytree import tree_map
from ._dim_entry import DimEntry
from ._enable_all_layers import EnableAllLayers
from ._tensor_info import TensorInfo
if TYPE_CHECKING:
from collections.abc import Callable
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Handle tensor conversion for torch function integration."""
return tensor
class WrappedOperator:
"""
This class wraps PyTorch operations to support first-class dimensions.
"""
def __init__(
self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
):
self.orig = orig
self.wrapper_implementation = wrapper_implementation
self.name = getattr(orig, "__name__", "")
self.doc = getattr(orig, "__doc__", None)
self.dim_name = dim_name
        # Default dispatch metadata; _wrap() overrides these per operator.
        self.is_pointwise = False  # flag for pointwise operators (not consulted in this module)
        self.dim_offset = 0  # positional offset of the dim argument, after self
        self.keepdim_offset = 1  # positional offset of the keepdim argument, after self
        self.single_dim = False  # whether the op takes a single dim rather than a list of dims
        self.reduce = True  # whether the op removes the reduced dims from its result
# Update docstring if we have a dim_name
if self.doc and self.dim_name:
self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"
def function(self) -> Callable:
"""Create a wrapped function that calls our wrapper implementation."""
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
return self.wrapper_implementation(self, *args, **kwargs)
        # Copy __name__ from the original via functools.update_wrapper; __doc__ is
        # set separately below from our (possibly augmented) doc string.
functools.update_wrapper(
wrapped_func, self.orig, assigned=("__name__",), updated=()
)
wrapped_func.__doc__ = self.doc
return wrapped_func
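
# Illustrative sketch (an assumption for exposition, not code that runs at
# import time): a WrappedOperator pairs an original torch callable with a
# wrapper implementation such as patched_dim_method below, and function()
# returns the callable that actually gets exposed to users:
#
#     op = WrappedOperator(torch.sum, patched_dim_method)
#     wrapped_sum = op.function()
#     # wrapped_sum(tensor, dim) now routes through patched_dim_method, which
#     # accepts either an integer or a first-class Dim for `dim`.
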
def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
"""Convert single dimension specification to DimEntry object."""
from . import Dim
if isinstance(dim, Dim):
if keepdim:
raise ValueError("cannot preserve first-class dimensions with keepdim=True")
return DimEntry(dim)
    elif isinstance(dim, int):
        # Normalize a non-negative index to the equivalent negative index
        # (e.g. dim=1 with ndim=3 becomes -2); negative indices pass through.
        i = dim
        while i >= 0:
            i -= ndim
        return DimEntry(i)
else:
return DimEntry()
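
# Examples of the normalization performed by _wrap_dim (illustrative values):
#     _wrap_dim(-1, ndim=3)   -> DimEntry(-1)   # negative indices pass through
#     _wrap_dim(1, ndim=3)    -> DimEntry(-2)   # positive index made negative
#     _wrap_dim(d, ndim=3)    -> DimEntry(d)    # d is a first-class Dim (keepdim must be False)
#     _wrap_dim(None, ndim=3) -> DimEntry()     # "none" entry: not a dim specification
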
def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
"""Convert dimension specification to list of DimEntry objects."""
de = _wrap_dim(dim, ndim, keepdim)
result = []
if not de.is_none():
result.append(de)
else:
for d in dim:
result.append(_wrap_dim(d, ndim, keepdim))
return result
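
# _wrap_dims accepts a single dim or an iterable of dims. Roughly (assuming
# x is a first-class Dim object):
#     _wrap_dims(x, ndim=3)      -> [DimEntry(x)]
#     _wrap_dims((x, 0), ndim=3) -> [DimEntry(x), DimEntry(-3)]
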
def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
"""
This is the core method that handles dimension-aware operations.
"""
if not args:
raise ValueError("Expected at least one argument (self)")
# Get dimension argument
    dim_arg = kwargs.get(wrapper.dim_name)
    if dim_arg is None:
        # Try to get dim from positional args (accounting for self at index 0)
        dim_idx = wrapper.dim_offset + 1
        if dim_idx < len(args):
            dim_arg = args[dim_idx]
# If no dimension argument provided, fall back to standard functorch handling
if dim_arg is None:
info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
if not info:
return wrapper.orig(*args, **kwargs)
with EnableAllLayers(info.levels) as guard:
assert info.batchedtensor is not None
guard.inplace_update_layers(info.batchedtensor, info.levels)
new_args = list(args)
new_args[0] = handle_from_tensor(info.batchedtensor)
result = wrapper.orig(*new_args, **kwargs)
return guard.from_batched(result, info.has_device)
# Handle dimension-aware operation
info = TensorInfo.create(args[0])
if not info:
return wrapper.orig(*args, **kwargs)
# Check for keepdim parameter
keepdim = False
if wrapper.reduce:
keepdim_arg = kwargs.get("keepdim")
        if keepdim_arg is None:
            # keepdim may also be passed positionally (offset by self at index 0)
            keepdim_idx = wrapper.keepdim_offset + 1
            if keepdim_idx < len(args):
                keepdim_arg = args[keepdim_idx]
if keepdim_arg is not None:
keepdim = bool(keepdim_arg)
# Wrap dimensions
ndim = info.ndim()
dims = _wrap_dims(dim_arg, ndim, keepdim)
# Convert dimensions to indices and validate
dim_indices: list[int] = []
seen = [False] * len(info.levels)
for d in dims:
midx = None
for i, level in enumerate(info.levels):
if level == d:
midx = i
break
if midx is None:
# Try to match by position/name more flexibly
for i, level in enumerate(info.levels):
if hasattr(level, "matches") and level.matches(d):
midx = i
break
if midx is None:
level_strs = [str(level) for level in info.levels]
raise ValueError(
f"Tensor with dimensions {level_strs} does not contain {d}"
)
seen[midx] = True
dim_indices.append(midx)
# Determine new levels after reduction
new_levels = []
if wrapper.reduce and not keepdim:
for i, level in enumerate(info.levels):
if not seen[i]:
new_levels.append(level)
else:
new_levels = info.levels[:]
# Create dimension indices for the original function
if len(dim_indices) == 1:
py_indices: Any = dim_indices[0]
else:
py_indices = tuple(dim_indices)
# Update arguments
new_args = list(args)
new_kwargs = kwargs.copy()
assert info.tensor is not None
new_args[0] = handle_from_tensor(info.tensor)
# Update dimension argument
if wrapper.dim_name in new_kwargs:
new_kwargs[wrapper.dim_name] = py_indices
    else:
        dim_idx = wrapper.dim_offset + 1
        if dim_idx < len(new_args):
            # new_args is already a fresh list (copied above), mutate in place
            new_args[dim_idx] = py_indices
# Call original function
result = wrapper.orig(*new_args, **new_kwargs)
# Wrap results
def wrap_result(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
from . import Tensor
return Tensor.from_positional(obj, new_levels, info.has_device)
return obj
return tree_map(wrap_result, result)
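
# Rough sketch of the path above for a reduction such as t.sum(c), where t
# carries first-class dimensions (illustrative only; nothing here executes):
#
#     import torch
#     from functorch.dim import dims
#     b, c = dims(2)
#     t = torch.randn(4, 5)[b, c]   # bind positional sizes to first-class dims
#     out = t.sum(c)                # reaches the wrapped sum via __torch_function__
#
# patched_dim_method looks up c's position in t's levels, calls the original
# sum on the underlying tensor with a plain integer dim, and rebuilds a
# first-class-dim Tensor whose levels no longer include c.
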
def _wrap(
orig: Callable,
dim_offset: Optional[int] = None,
keepdim_offset: Optional[int] = None,
dim_name: Optional[str] = None,
single_dim: Optional[bool] = None,
reduce: Optional[bool] = None,
) -> Callable:
"""
Wrap a PyTorch function to support first-class dimensions.
Args:
orig: Original function to wrap
dim_offset: Offset for dimension argument (default: 0)
keepdim_offset: Offset for keepdim argument (default: 1)
dim_name: Name of dimension parameter (default: "dim")
single_dim: Whether function takes single dimension (default: False)
reduce: Whether function reduces dimensions (default: True)
"""
dim_name = dim_name or "dim"
wrapper = WrappedOperator(orig, patched_dim_method, dim_name)
if dim_offset is not None:
wrapper.dim_offset = dim_offset
if keepdim_offset is not None:
wrapper.keepdim_offset = keepdim_offset
if single_dim is not None:
wrapper.single_dim = single_dim
if reduce is not None:
wrapper.reduce = reduce
return wrapper.function()
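
# Hedged usage sketch: _wrap is the factory expected to produce the dim-aware
# variants of reduction-style operators. The concrete registrations live
# elsewhere in functorch.dim; the calls below are assumptions for illustration:
#
#     dim_aware_sum = _wrap(torch.Tensor.sum)
#     dim_aware_softmax = _wrap(torch.nn.functional.softmax, reduce=False)
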
def call_torch_function(
wrapper: WrappedOperator,
func: Callable,
types: tuple,
args: tuple = (),
kwargs: Optional[dict] = None,
) -> Any:
"""
Handle __torch_function__ calls for wrapped operators.
"""
if kwargs is None:
kwargs = {}
# Import here to avoid circular imports
from . import _Tensor
# Use the torch function mechanism from _Tensor
return _Tensor.__torch_function__(func, types, args, kwargs)
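
# Speculative wiring sketch (an assumption, not the module's actual plumbing):
# call_torch_function is shaped so it can be bound to a particular wrapper,
# e.g. with functools.partial, and used as that wrapped op's
# __torch_function__ entry point:
#
#     wrapper = WrappedOperator(torch.softmax, patched_dim_method)
#     handler = functools.partial(call_torch_function, wrapper)
#     # handler(func, types, args, kwargs) then defers to _Tensor.__torch_function__
#     # so first-class-dim tensors are handled uniformly.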