Add squeeze, unsqueeze, matmul, select, subtract to stable ops (#169880)

These ops are ported from https://github.com/pytorch/audio/blob/main/src/libtorchaudio/stable/ops.h

Technically it should have been OK not to port these, but looking at them carefully I realized that the `subtract` ported to torchaudio ~would have undefined behavior :/~ is broken:

```cpp
inline Tensor subtract(const Tensor& self, const Tensor& other) {
  const auto num_args = 2;
  std::array<StableIValue, num_args> stack{
      torch::stable::detail::from(self), torch::stable::detail::from(other)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::subtract", "Tensor", stack.data(), TORCH_ABI_VERSION));
  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
```

as it is missing `alpha`: the signature for `subtract.Tensor` is `func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor`. ~This is also our bad: although out-of-bounds reads on the StableIValue stack would be caught by ASAN, without ASAN they are silent correctness issues (PR coming to fix).~
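
To make the struck-through concern concrete, here is a minimal stand-alone sketch, using a hypothetical `Slot` type and `fake_dispatch` function rather than the real PyTorch API, of why an undersized argument stack is a silent correctness issue: the callee unpacks one slot per schema argument and walks past the end of the caller's array, which ASAN flags but a regular build executes silently.

```cpp
// Hypothetical illustration only: `Slot` and `fake_dispatch` are stand-ins,
// not PyTorch APIs. A 3-argument schema unpacked from a 2-slot array reads
// one slot past the end of the caller's storage.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

using Slot = uint64_t; // stand-in for an opaque boxed-argument slot

// Stand-in for the callee, which unpacks one slot per schema argument.
void fake_dispatch(const Slot* stack, std::size_t schema_num_args) {
  for (std::size_t i = 0; i < schema_num_args; ++i) {
    // With the caller below, i == 2 reads past the 2-element array:
    // undefined behavior that ASAN catches but a normal build runs silently.
    std::printf("arg[%zu] = %llu\n", i, static_cast<unsigned long long>(stack[i]));
  }
}

int main() {
  std::array<Slot, 2> stack{1, 2}; // packed only (self, other), forgot alpha
  fake_dispatch(stack.data(), /*schema_num_args=*/3); // out-of-bounds read
}
```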

This PR uses the old C-shim path to support `subtract`, since we don't yet support StableIValue conversion for `Scalar`.
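
Concretely, the wrapper added to `torch/csrc/stable/ops.h` (shown in full in the diff below) takes `alpha` as a `double`, matching the pre-generated C shim, and skips the boxed dispatcher entirely:

```cpp
// Mirrors the code added in torch/csrc/stable/ops.h by this PR: alpha is a
// double because the C shim takes a double for the Scalar, and we call the
// shim directly since StableIValue conversion for Scalar does not exist yet.
inline torch::stable::Tensor subtract(
    const torch::stable::Tensor& self,
    const torch::stable::Tensor& other,
    double alpha = 1.0) {
  AtenTensorHandle ret0;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_aten_subtract_Tensor(self.get(), other.get(), alpha, &ret0));
  return torch::stable::Tensor(ret0);
}
```

Once a StableIValue conversion for `Scalar` lands, this can presumably be migrated to `torch_call_dispatcher` like the other ops in this PR.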

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169880
Approved by: https://github.com/albanD
ghstack dependencies: #169703, #169709, #169711, #168062, #169872
Author: Mikayla Gawarecki
Date: 2025-12-08 13:46:30 -08:00
Committed by: PyTorch MergeBot
Parent: 9074394398
Commit: dc8fe8f971
8 changed files with 343 additions and 0 deletions

@@ -0,0 +1,17 @@
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
using torch::stable::Tensor;
Tensor my_subtract(const Tensor& self, const Tensor& other, double alpha) {
return torch::stable::subtract(self, other, alpha);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
m.def("my_subtract(Tensor self, Tensor other, float alpha=1.0) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
m.impl("my_subtract", TORCH_BOX(&my_subtract));
}

@@ -659,3 +659,19 @@ def my_full(
return torch.ops.libtorch_agnostic_2_10.my_full.default(
size, fill_value, dtype, layout, device, pin_memory
)
def my_subtract(self, other, alpha=1.0) -> Tensor:
"""
Subtracts other from self, scaled by alpha.
Computes: self - alpha * other
Args:
self: Tensor - input tensor
other: Tensor - tensor to subtract
alpha: float - scaling factor for other (default: 1.0)
Returns: Tensor - result of subtraction
"""
return torch.ops.libtorch_agnostic_2_10.my_subtract.default(self, other, alpha)

@@ -496,3 +496,33 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
m.impl("my_storage_offset", TORCH_BOX(&my_storage_offset));
m.impl("my_element_size", TORCH_BOX(&my_element_size));
}
Tensor my_unsqueeze(const Tensor& t, int64_t dim) {
return torch::stable::unsqueeze(t, dim);
}
Tensor my_squeeze(const Tensor& t, int64_t dim) {
return torch::stable::squeeze(t, dim);
}
Tensor my_select(const Tensor& t, int64_t dim, int64_t index) {
return torch::stable::select(t, dim, index);
}
Tensor my_matmul(const Tensor& self, const Tensor& other) {
return torch::stable::matmul(self, other);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) {
m.def("my_unsqueeze(Tensor t, int dim) -> Tensor");
m.def("my_squeeze(Tensor t, int dim) -> Tensor");
m.def("my_select(Tensor t, int dim, int index) -> Tensor");
m.def("my_matmul(Tensor self, Tensor other) -> Tensor");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
m.impl("my_unsqueeze", TORCH_BOX(&my_unsqueeze));
m.impl("my_squeeze", TORCH_BOX(&my_squeeze));
m.impl("my_select", TORCH_BOX(&my_select));
m.impl("my_matmul", TORCH_BOX(&my_matmul));
}

@@ -401,3 +401,56 @@ def my_element_size(t) -> int:
Returns: int - element size in bytes
"""
return torch.ops.libtorch_agnostic_2_9.my_element_size.default(t)
def my_unsqueeze(t, dim) -> Tensor:
"""
Returns a new tensor with a dimension of size one inserted at the specified position.
Args:
t: Tensor - input tensor
dim: int - the index at which to insert the singleton dimension
Returns: Tensor - unsqueezed tensor
"""
return torch.ops.libtorch_agnostic_2_9.my_unsqueeze.default(t, dim)
def my_squeeze(t, dim) -> Tensor:
"""
Returns a tensor with the specified dimension of size 1 removed.
Args:
t: Tensor - input tensor
dim: int - the dimension to squeeze
Returns: Tensor - squeezed tensor
"""
return torch.ops.libtorch_agnostic_2_9.my_squeeze.default(t, dim)
def my_select(t, dim, index) -> Tensor:
"""
Slices the tensor along the selected dimension at the given index.
Args:
t: Tensor - input tensor
dim: int - the dimension to slice along
index: int - the index to select
Returns: Tensor - sliced tensor with one fewer dimension
"""
return torch.ops.libtorch_agnostic_2_9.my_select.default(t, dim, index)
def my_matmul(self, other) -> Tensor:
"""
Matrix product of two tensors.
Args:
self: Tensor - first tensor
other: Tensor - second tensor
Returns: Tensor - matrix product
"""
return torch.ops.libtorch_agnostic_2_9.my_matmul.default(self, other)

@@ -1541,6 +1541,140 @@ except RuntimeError as e:
expected_both = t.new_zeros([4, 5], dtype=torch.int64, device="cpu")
self.assertEqual(result_both, expected_both, exact_device=True)
def test_my_unsqueeze(self, device):
"""Test unsqueeze op."""
import libtorch_agnostic_2_9 as libtorch_agnostic
t = torch.randn(3, 4, device=device)
# Test unsqueeze at dim 0
result = libtorch_agnostic.ops.my_unsqueeze(t, 0)
expected = torch.unsqueeze(t, 0)
self.assertEqual(result, expected)
self.assertEqual(result.shape, torch.Size([1, 3, 4]))
# Test unsqueeze at dim 1
result1 = libtorch_agnostic.ops.my_unsqueeze(t, 1)
expected1 = torch.unsqueeze(t, 1)
self.assertEqual(result1, expected1)
self.assertEqual(result1.shape, torch.Size([3, 1, 4]))
# Test unsqueeze at dim -1
result_neg = libtorch_agnostic.ops.my_unsqueeze(t, -1)
expected_neg = torch.unsqueeze(t, -1)
self.assertEqual(result_neg, expected_neg)
self.assertEqual(result_neg.shape, torch.Size([3, 4, 1]))
def test_my_squeeze(self, device):
"""Test squeeze.dim op."""
import libtorch_agnostic_2_9 as libtorch_agnostic
t = torch.randn(3, 1, 4, device=device)
# Test squeeze at dim 1 (the dimension of size 1)
result = libtorch_agnostic.ops.my_squeeze(t, 1)
expected = torch.squeeze(t, 1)
self.assertEqual(result, expected)
self.assertEqual(result.shape, torch.Size([3, 4]))
# Test squeeze at dim 0 (not size 1, should be no-op)
result0 = libtorch_agnostic.ops.my_squeeze(t, 0)
expected0 = torch.squeeze(t, 0)
self.assertEqual(result0, expected0)
self.assertEqual(result0.shape, torch.Size([3, 1, 4]))
# Test squeeze at dim -2 (same as dim 1)
result_neg = libtorch_agnostic.ops.my_squeeze(t, -2)
expected_neg = torch.squeeze(t, -2)
self.assertEqual(result_neg, expected_neg)
self.assertEqual(result_neg.shape, torch.Size([3, 4]))
def test_my_select(self, device):
"""Test select.int op."""
import libtorch_agnostic_2_9 as libtorch_agnostic
t = torch.randn(3, 4, 5, device=device)
# Test select at dim 0, index 1
result = libtorch_agnostic.ops.my_select(t, 0, 1)
expected = torch.select(t, 0, 1)
self.assertEqual(result, expected)
self.assertEqual(result.shape, torch.Size([4, 5]))
# Test select at dim 1, index 2
result1 = libtorch_agnostic.ops.my_select(t, 1, 2)
expected1 = torch.select(t, 1, 2)
self.assertEqual(result1, expected1)
self.assertEqual(result1.shape, torch.Size([3, 5]))
# Test select at dim -1, index 0
result_neg = libtorch_agnostic.ops.my_select(t, -1, 0)
expected_neg = torch.select(t, -1, 0)
self.assertEqual(result_neg, expected_neg)
self.assertEqual(result_neg.shape, torch.Size([3, 4]))
def test_my_matmul(self, device):
"""Test matmul op."""
import libtorch_agnostic_2_9 as libtorch_agnostic
# Test 2D x 2D matrix multiplication
a = torch.randn(3, 4, device=device)
b = torch.randn(4, 5, device=device)
result = libtorch_agnostic.ops.my_matmul(a, b)
expected = torch.matmul(a, b)
self.assertEqual(result, expected)
self.assertEqual(result.shape, torch.Size([3, 5]))
# Test 1D x 2D (vector-matrix)
v = torch.randn(4, device=device)
m = torch.randn(4, 5, device=device)
result_vm = libtorch_agnostic.ops.my_matmul(v, m)
expected_vm = torch.matmul(v, m)
self.assertEqual(result_vm, expected_vm)
# Test 2D x 1D (matrix-vector)
m2 = torch.randn(3, 4, device=device)
v2 = torch.randn(4, device=device)
result_mv = libtorch_agnostic.ops.my_matmul(m2, v2)
expected_mv = torch.matmul(m2, v2)
self.assertEqual(result_mv, expected_mv)
# Test batched matmul
batch_a = torch.randn(2, 3, 4, device=device)
batch_b = torch.randn(2, 4, 5, device=device)
result_batch = libtorch_agnostic.ops.my_matmul(batch_a, batch_b)
expected_batch = torch.matmul(batch_a, batch_b)
self.assertEqual(result_batch, expected_batch)
@skipIfTorchVersionLessThan(2, 10)
def test_my_subtract(self, device):
"""Test subtract.Tensor op."""
import libtorch_agnostic_2_10 as libtorch_agnostic
a = torch.randn(3, 4, device=device)
b = torch.randn(3, 4, device=device)
# Test basic subtraction (alpha=1.0)
result = libtorch_agnostic.ops.my_subtract(a, b)
expected = torch.subtract(a, b)
self.assertEqual(result, expected)
# Test subtraction with alpha=2.0
result_alpha = libtorch_agnostic.ops.my_subtract(a, b, alpha=2.0)
expected_alpha = torch.subtract(a, b, alpha=2.0)
self.assertEqual(result_alpha, expected_alpha)
# Test subtraction with alpha=0.5
result_half = libtorch_agnostic.ops.my_subtract(a, b, alpha=0.5)
expected_half = torch.subtract(a, b, alpha=0.5)
self.assertEqual(result_half, expected_half)
# Test subtraction with broadcasting
c = torch.randn(4, device=device)
result_broadcast = libtorch_agnostic.ops.my_subtract(a, c)
expected_broadcast = torch.subtract(a, c)
self.assertEqual(result_broadcast, expected_broadcast)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

@@ -21,6 +21,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, i
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_zeros(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_subtract_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0);
#ifdef __cplusplus
} // extern "C"

@@ -295,6 +295,82 @@ inline torch::stable::Tensor flatten(
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the unsqueeze op with identical
// semantics to the existing unsqueeze op.
inline torch::stable::Tensor unsqueeze(
const torch::stable::Tensor& self,
int64_t dim) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::unsqueeze", "", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the squeeze.dim op with identical
// semantics to the existing squeeze.dim op.
inline torch::stable::Tensor squeeze(
const torch::stable::Tensor& self,
int64_t dim) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::squeeze", "dim", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::squeeze", "dim", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the select.int op with identical
// semantics to the existing select.int op.
// Note: index is typed as int64_t because SymInt is not yet header-only.
inline torch::stable::Tensor select(
const torch::stable::Tensor& self,
int64_t dim,
int64_t index) {
const auto num_args = 3;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self),
torch::stable::detail::from(dim),
torch::stable::detail::from(index)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::select", "int", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::select", "int", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
// We expect this to be the stable version of the matmul op with identical
// semantics to the existing matmul op.
inline torch::stable::Tensor matmul(
const torch::stable::Tensor& self,
const torch::stable::Tensor& other) {
const auto num_args = 2;
std::array<StableIValue, num_args> stack{
torch::stable::detail::from(self), torch::stable::detail::from(other)};
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
"aten::matmul", "", stack.data(), TORCH_ABI_VERSION));
#else
TORCH_ERROR_CODE_CHECK(
aoti_torch_call_dispatcher("aten::matmul", "", stack.data()));
#endif
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
}
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
// New ops should be added here if they use a brand new shim API
@@ -578,6 +654,21 @@ inline torch::stable::Tensor& sum_out(
return out;
}
// We expect this to be the stable version of the subtract.Tensor op.
// Note: alpha is typed as double because the underlying C shim API
// uses double for the Scalar parameter. We don't use torch_call_dispatcher
// as the stableivalue conversion for Scalar is not yet available as of
// 2.10
inline torch::stable::Tensor subtract(
const torch::stable::Tensor& self,
const torch::stable::Tensor& other,
double alpha = 1.0) {
AtenTensorHandle ret0;
TORCH_ERROR_CODE_CHECK(
aoti_torch_aten_subtract_Tensor(self.get(), other.get(), alpha, &ret0));
return torch::stable::Tensor(ret0);
}
// We expect this to be the stable version of the full.default op.
// Note: fill_value is typed as double because the underlying C shim API
// uses double for the Scalar parameter. We don't use torch_call_dispatcher

@@ -190,4 +190,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
"aten.new_empty.default": {},
"aten.new_zeros.default": {},
"aten.full.default": {},
"aten.subtract.Tensor": {},
}