mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Add squeeze, unsqueeze, matmul, select, subtract to stable ops (#169880)
From https://github.com/pytorch/audio/blob/main/src/libtorchaudio/stable/ops.h Technically it should have been ok not to port these but looking at these carefully I realized the subtract ported to audio ~would have undefined behavior :/~ is broken ``` inline Tensor subtract(const Tensor& self, const Tensor& other) { const auto num_args = 2; std::array<StableIValue, num_args> stack{ torch::stable::detail::from(self), torch::stable::detail::from(other)}; TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( "aten::subtract", "Tensor", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to<torch::stable::Tensor>(stack[0]); } ``` as it missed `alpha` the signature for `subtract.Tensor` is `func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor`. ~This is also our bad as although out of bounds reads on the stableivalue stack would be caught by asan, without asan they are silent correctness issues (PR coming to fix).~ Use the old path to support this as we don't support stableivalue conversion for Scalar yet. Pull Request resolved: https://github.com/pytorch/pytorch/pull/169880 Approved by: https://github.com/albanD ghstack dependencies: #169703, #169709, #169711, #168062, #169872
This commit is contained in:
committed by
PyTorch MergeBot
parent
9074394398
commit
dc8fe8f971
@@ -0,0 +1,17 @@
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
using torch::stable::Tensor;
|
||||
|
||||
Tensor my_subtract(const Tensor& self, const Tensor& other, double alpha) {
|
||||
return torch::stable::subtract(self, other, alpha);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
|
||||
m.def("my_subtract(Tensor self, Tensor other, float alpha=1.0) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_subtract", TORCH_BOX(&my_subtract));
|
||||
}
|
||||
@@ -659,3 +659,19 @@ def my_full(
|
||||
return torch.ops.libtorch_agnostic_2_10.my_full.default(
|
||||
size, fill_value, dtype, layout, device, pin_memory
|
||||
)
|
||||
|
||||
|
||||
def my_subtract(self, other, alpha=1.0) -> Tensor:
|
||||
"""
|
||||
Subtracts other from self, scaled by alpha.
|
||||
|
||||
Computes: self - alpha * other
|
||||
|
||||
Args:
|
||||
self: Tensor - input tensor
|
||||
other: Tensor - tensor to subtract
|
||||
alpha: float - scaling factor for other (default: 1.0)
|
||||
|
||||
Returns: Tensor - result of subtraction
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic_2_10.my_subtract.default(self, other, alpha)
|
||||
|
||||
@@ -496,3 +496,33 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_storage_offset", TORCH_BOX(&my_storage_offset));
|
||||
m.impl("my_element_size", TORCH_BOX(&my_element_size));
|
||||
}
|
||||
|
||||
Tensor my_unsqueeze(const Tensor& t, int64_t dim) {
|
||||
return torch::stable::unsqueeze(t, dim);
|
||||
}
|
||||
|
||||
Tensor my_squeeze(const Tensor& t, int64_t dim) {
|
||||
return torch::stable::squeeze(t, dim);
|
||||
}
|
||||
|
||||
Tensor my_select(const Tensor& t, int64_t dim, int64_t index) {
|
||||
return torch::stable::select(t, dim, index);
|
||||
}
|
||||
|
||||
Tensor my_matmul(const Tensor& self, const Tensor& other) {
|
||||
return torch::stable::matmul(self, other);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) {
|
||||
m.def("my_unsqueeze(Tensor t, int dim) -> Tensor");
|
||||
m.def("my_squeeze(Tensor t, int dim) -> Tensor");
|
||||
m.def("my_select(Tensor t, int dim, int index) -> Tensor");
|
||||
m.def("my_matmul(Tensor self, Tensor other) -> Tensor");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_unsqueeze", TORCH_BOX(&my_unsqueeze));
|
||||
m.impl("my_squeeze", TORCH_BOX(&my_squeeze));
|
||||
m.impl("my_select", TORCH_BOX(&my_select));
|
||||
m.impl("my_matmul", TORCH_BOX(&my_matmul));
|
||||
}
|
||||
|
||||
@@ -401,3 +401,56 @@ def my_element_size(t) -> int:
|
||||
Returns: int - element size in bytes
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic_2_9.my_element_size.default(t)
|
||||
|
||||
|
||||
def my_unsqueeze(t, dim) -> Tensor:
|
||||
"""
|
||||
Returns a new tensor with a dimension of size one inserted at the specified position.
|
||||
|
||||
Args:
|
||||
t: Tensor - input tensor
|
||||
dim: int - the index at which to insert the singleton dimension
|
||||
|
||||
Returns: Tensor - unsqueezed tensor
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic_2_9.my_unsqueeze.default(t, dim)
|
||||
|
||||
|
||||
def my_squeeze(t, dim) -> Tensor:
|
||||
"""
|
||||
Returns a tensor with the specified dimension of size 1 removed.
|
||||
|
||||
Args:
|
||||
t: Tensor - input tensor
|
||||
dim: int - the dimension to squeeze
|
||||
|
||||
Returns: Tensor - squeezed tensor
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic_2_9.my_squeeze.default(t, dim)
|
||||
|
||||
|
||||
def my_select(t, dim, index) -> Tensor:
|
||||
"""
|
||||
Slices the tensor along the selected dimension at the given index.
|
||||
|
||||
Args:
|
||||
t: Tensor - input tensor
|
||||
dim: int - the dimension to slice along
|
||||
index: int - the index to select
|
||||
|
||||
Returns: Tensor - sliced tensor with one fewer dimension
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic_2_9.my_select.default(t, dim, index)
|
||||
|
||||
|
||||
def my_matmul(self, other) -> Tensor:
|
||||
"""
|
||||
Matrix product of two tensors.
|
||||
|
||||
Args:
|
||||
self: Tensor - first tensor
|
||||
other: Tensor - second tensor
|
||||
|
||||
Returns: Tensor - matrix product
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic_2_9.my_matmul.default(self, other)
|
||||
|
||||
@@ -1541,6 +1541,140 @@ except RuntimeError as e:
|
||||
expected_both = t.new_zeros([4, 5], dtype=torch.int64, device="cpu")
|
||||
self.assertEqual(result_both, expected_both, exact_device=True)
|
||||
|
||||
def test_my_unsqueeze(self, device):
|
||||
"""Test unsqueeze op."""
|
||||
import libtorch_agnostic_2_9 as libtorch_agnostic
|
||||
|
||||
t = torch.randn(3, 4, device=device)
|
||||
|
||||
# Test unsqueeze at dim 0
|
||||
result = libtorch_agnostic.ops.my_unsqueeze(t, 0)
|
||||
expected = torch.unsqueeze(t, 0)
|
||||
self.assertEqual(result, expected)
|
||||
self.assertEqual(result.shape, torch.Size([1, 3, 4]))
|
||||
|
||||
# Test unsqueeze at dim 1
|
||||
result1 = libtorch_agnostic.ops.my_unsqueeze(t, 1)
|
||||
expected1 = torch.unsqueeze(t, 1)
|
||||
self.assertEqual(result1, expected1)
|
||||
self.assertEqual(result1.shape, torch.Size([3, 1, 4]))
|
||||
|
||||
# Test unsqueeze at dim -1
|
||||
result_neg = libtorch_agnostic.ops.my_unsqueeze(t, -1)
|
||||
expected_neg = torch.unsqueeze(t, -1)
|
||||
self.assertEqual(result_neg, expected_neg)
|
||||
self.assertEqual(result_neg.shape, torch.Size([3, 4, 1]))
|
||||
|
||||
def test_my_squeeze(self, device):
|
||||
"""Test squeeze.dim op."""
|
||||
import libtorch_agnostic_2_9 as libtorch_agnostic
|
||||
|
||||
t = torch.randn(3, 1, 4, device=device)
|
||||
|
||||
# Test squeeze at dim 1 (the dimension of size 1)
|
||||
result = libtorch_agnostic.ops.my_squeeze(t, 1)
|
||||
expected = torch.squeeze(t, 1)
|
||||
self.assertEqual(result, expected)
|
||||
self.assertEqual(result.shape, torch.Size([3, 4]))
|
||||
|
||||
# Test squeeze at dim 0 (not size 1, should be no-op)
|
||||
result0 = libtorch_agnostic.ops.my_squeeze(t, 0)
|
||||
expected0 = torch.squeeze(t, 0)
|
||||
self.assertEqual(result0, expected0)
|
||||
self.assertEqual(result0.shape, torch.Size([3, 1, 4]))
|
||||
|
||||
# Test squeeze at dim -2 (same as dim 1)
|
||||
result_neg = libtorch_agnostic.ops.my_squeeze(t, -2)
|
||||
expected_neg = torch.squeeze(t, -2)
|
||||
self.assertEqual(result_neg, expected_neg)
|
||||
self.assertEqual(result_neg.shape, torch.Size([3, 4]))
|
||||
|
||||
def test_my_select(self, device):
|
||||
"""Test select.int op."""
|
||||
import libtorch_agnostic_2_9 as libtorch_agnostic
|
||||
|
||||
t = torch.randn(3, 4, 5, device=device)
|
||||
|
||||
# Test select at dim 0, index 1
|
||||
result = libtorch_agnostic.ops.my_select(t, 0, 1)
|
||||
expected = torch.select(t, 0, 1)
|
||||
self.assertEqual(result, expected)
|
||||
self.assertEqual(result.shape, torch.Size([4, 5]))
|
||||
|
||||
# Test select at dim 1, index 2
|
||||
result1 = libtorch_agnostic.ops.my_select(t, 1, 2)
|
||||
expected1 = torch.select(t, 1, 2)
|
||||
self.assertEqual(result1, expected1)
|
||||
self.assertEqual(result1.shape, torch.Size([3, 5]))
|
||||
|
||||
# Test select at dim -1, index 0
|
||||
result_neg = libtorch_agnostic.ops.my_select(t, -1, 0)
|
||||
expected_neg = torch.select(t, -1, 0)
|
||||
self.assertEqual(result_neg, expected_neg)
|
||||
self.assertEqual(result_neg.shape, torch.Size([3, 4]))
|
||||
|
||||
def test_my_matmul(self, device):
|
||||
"""Test matmul op."""
|
||||
import libtorch_agnostic_2_9 as libtorch_agnostic
|
||||
|
||||
# Test 2D x 2D matrix multiplication
|
||||
a = torch.randn(3, 4, device=device)
|
||||
b = torch.randn(4, 5, device=device)
|
||||
result = libtorch_agnostic.ops.my_matmul(a, b)
|
||||
expected = torch.matmul(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
self.assertEqual(result.shape, torch.Size([3, 5]))
|
||||
|
||||
# Test 1D x 2D (vector-matrix)
|
||||
v = torch.randn(4, device=device)
|
||||
m = torch.randn(4, 5, device=device)
|
||||
result_vm = libtorch_agnostic.ops.my_matmul(v, m)
|
||||
expected_vm = torch.matmul(v, m)
|
||||
self.assertEqual(result_vm, expected_vm)
|
||||
|
||||
# Test 2D x 1D (matrix-vector)
|
||||
m2 = torch.randn(3, 4, device=device)
|
||||
v2 = torch.randn(4, device=device)
|
||||
result_mv = libtorch_agnostic.ops.my_matmul(m2, v2)
|
||||
expected_mv = torch.matmul(m2, v2)
|
||||
self.assertEqual(result_mv, expected_mv)
|
||||
|
||||
# Test batched matmul
|
||||
batch_a = torch.randn(2, 3, 4, device=device)
|
||||
batch_b = torch.randn(2, 4, 5, device=device)
|
||||
result_batch = libtorch_agnostic.ops.my_matmul(batch_a, batch_b)
|
||||
expected_batch = torch.matmul(batch_a, batch_b)
|
||||
self.assertEqual(result_batch, expected_batch)
|
||||
|
||||
@skipIfTorchVersionLessThan(2, 10)
|
||||
def test_my_subtract(self, device):
|
||||
"""Test subtract.Tensor op."""
|
||||
import libtorch_agnostic_2_10 as libtorch_agnostic
|
||||
|
||||
a = torch.randn(3, 4, device=device)
|
||||
b = torch.randn(3, 4, device=device)
|
||||
|
||||
# Test basic subtraction (alpha=1.0)
|
||||
result = libtorch_agnostic.ops.my_subtract(a, b)
|
||||
expected = torch.subtract(a, b)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
# Test subtraction with alpha=2.0
|
||||
result_alpha = libtorch_agnostic.ops.my_subtract(a, b, alpha=2.0)
|
||||
expected_alpha = torch.subtract(a, b, alpha=2.0)
|
||||
self.assertEqual(result_alpha, expected_alpha)
|
||||
|
||||
# Test subtraction with alpha=0.5
|
||||
result_half = libtorch_agnostic.ops.my_subtract(a, b, alpha=0.5)
|
||||
expected_half = torch.subtract(a, b, alpha=0.5)
|
||||
self.assertEqual(result_half, expected_half)
|
||||
|
||||
# Test subtraction with broadcasting
|
||||
c = torch.randn(4, device=device)
|
||||
result_broadcast = libtorch_agnostic.ops.my_subtract(a, c)
|
||||
expected_broadcast = torch.subtract(a, c)
|
||||
self.assertEqual(result_broadcast, expected_broadcast)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,6 +21,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, i
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_zeros(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_subtract_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
|
||||
@@ -295,6 +295,82 @@ inline torch::stable::Tensor flatten(
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the unsqueeze op with identical
|
||||
// semantics to the existing unsqueeze op.
|
||||
inline torch::stable::Tensor unsqueeze(
|
||||
const torch::stable::Tensor& self,
|
||||
int64_t dim) {
|
||||
const auto num_args = 2;
|
||||
std::array<StableIValue, num_args> stack{
|
||||
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::unsqueeze", "", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the squeeze.dim op with identical
|
||||
// semantics to the existing squeeze.dim op.
|
||||
inline torch::stable::Tensor squeeze(
|
||||
const torch::stable::Tensor& self,
|
||||
int64_t dim) {
|
||||
const auto num_args = 2;
|
||||
std::array<StableIValue, num_args> stack{
|
||||
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::squeeze", "dim", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::squeeze", "dim", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the select.int op with identical
|
||||
// semantics to the existing select.int op.
|
||||
// Note: index is typed as int64_t because SymInt is not yet header-only.
|
||||
inline torch::stable::Tensor select(
|
||||
const torch::stable::Tensor& self,
|
||||
int64_t dim,
|
||||
int64_t index) {
|
||||
const auto num_args = 3;
|
||||
std::array<StableIValue, num_args> stack{
|
||||
torch::stable::detail::from(self),
|
||||
torch::stable::detail::from(dim),
|
||||
torch::stable::detail::from(index)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::select", "int", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::select", "int", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the matmul op with identical
|
||||
// semantics to the existing matmul op.
|
||||
inline torch::stable::Tensor matmul(
|
||||
const torch::stable::Tensor& self,
|
||||
const torch::stable::Tensor& other) {
|
||||
const auto num_args = 2;
|
||||
std::array<StableIValue, num_args> stack{
|
||||
torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::matmul", "", stack.data(), TORCH_ABI_VERSION));
|
||||
#else
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_call_dispatcher("aten::matmul", "", stack.data()));
|
||||
#endif
|
||||
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
||||
|
||||
// New ops should be added here if they use a brand new shim API
|
||||
@@ -578,6 +654,21 @@ inline torch::stable::Tensor& sum_out(
|
||||
return out;
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the subtract.Tensor op.
|
||||
// Note: alpha is typed as double because the underlying C shim API
|
||||
// uses double for the Scalar parameter. We don't use torch_call_dispatcher
|
||||
// as the stableivalue conversion for Scalar is not yet available as of
|
||||
// 2.10
|
||||
inline torch::stable::Tensor subtract(
|
||||
const torch::stable::Tensor& self,
|
||||
const torch::stable::Tensor& other,
|
||||
double alpha = 1.0) {
|
||||
AtenTensorHandle ret0;
|
||||
TORCH_ERROR_CODE_CHECK(
|
||||
aoti_torch_aten_subtract_Tensor(self.get(), other.get(), alpha, &ret0));
|
||||
return torch::stable::Tensor(ret0);
|
||||
}
|
||||
|
||||
// We expect this to be the stable version of the full.default op.
|
||||
// Note: fill_value is typed as double because the underlying C shim API
|
||||
// uses double for the Scalar parameter. We don't use torch_call_dispatcher
|
||||
|
||||
@@ -190,4 +190,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
|
||||
"aten.new_empty.default": {},
|
||||
"aten.new_zeros.default": {},
|
||||
"aten.full.default": {},
|
||||
"aten.subtract.Tensor": {},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user