torch::stable op wrappers: unsqueeze, squeeze.dim, select.int, matmul, subtract.Tensor

Reconstructed from a whitespace-collapsed copy of this patch in which line
breaks, indentation and every <...> span had been lost. Hunk headers, index
lines and all code tokens are taken from that copy; the following were
re-inserted and should be confirmed against the upstream commit:
  * the three #include targets in my_subtract.cpp (TODO confirm)
  * std::array<StableIValue, num_args> and
    torch::stable::detail::to<torch::stable::Tensor> in ops.h
  * indentation and line wrapping (added/context line counts match every
    hunk header)

diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_subtract.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_subtract.cpp
new file mode 100644
index 00000000000..933aa1d9a59
--- /dev/null
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_subtract.cpp
@@ -0,0 +1,17 @@
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
+#include <torch/csrc/stable/tensor.h>
+
+using torch::stable::Tensor;
+
+Tensor my_subtract(const Tensor& self, const Tensor& other, double alpha) {
+  return torch::stable::subtract(self, other, alpha);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) {
+  m.def("my_subtract(Tensor self, Tensor other, float alpha=1.0) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) {
+  m.impl("my_subtract", TORCH_BOX(&my_subtract));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py
index 17937a87786..f1bd463d3de 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py
@@ -659,3 +659,19 @@ def my_full(
     return torch.ops.libtorch_agnostic_2_10.my_full.default(
         size, fill_value, dtype, layout, device, pin_memory
     )
+
+
+def my_subtract(self, other, alpha=1.0) -> Tensor:
+    """
+    Subtracts other from self, scaled by alpha.
+
+    Computes: self - alpha * other
+
+    Args:
+        self: Tensor - input tensor
+        other: Tensor - tensor to subtract
+        alpha: float - scaling factor for other (default: 1.0)
+
+    Returns: Tensor - result of subtraction
+    """
+    return torch.ops.libtorch_agnostic_2_10.my_subtract.default(self, other, alpha)
diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp
index e03efa7f0d3..8ce093241d6 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/kernel.cpp
@@ -496,3 +496,33 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
   m.impl("my_storage_offset", TORCH_BOX(&my_storage_offset));
   m.impl("my_element_size", TORCH_BOX(&my_element_size));
 }
+
+Tensor my_unsqueeze(const Tensor& t, int64_t dim) {
+  return torch::stable::unsqueeze(t, dim);
+}
+
+Tensor my_squeeze(const Tensor& t, int64_t dim) {
+  return torch::stable::squeeze(t, dim);
+}
+
+Tensor my_select(const Tensor& t, int64_t dim, int64_t index) {
+  return torch::stable::select(t, dim, index);
+}
+
+Tensor my_matmul(const Tensor& self, const Tensor& other) {
+  return torch::stable::matmul(self, other);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_9, m) {
+  m.def("my_unsqueeze(Tensor t, int dim) -> Tensor");
+  m.def("my_squeeze(Tensor t, int dim) -> Tensor");
+  m.def("my_select(Tensor t, int dim, int index) -> Tensor");
+  m.def("my_matmul(Tensor self, Tensor other) -> Tensor");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_9, CompositeExplicitAutograd, m) {
+  m.impl("my_unsqueeze", TORCH_BOX(&my_unsqueeze));
+  m.impl("my_squeeze", TORCH_BOX(&my_squeeze));
+  m.impl("my_select", TORCH_BOX(&my_select));
+  m.impl("my_matmul", TORCH_BOX(&my_matmul));
+}
diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py
index 01c189d2a2d..2ddf62bb401 100644
--- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/ops.py
@@ -401,3 +401,56 @@ def my_element_size(t) -> int:
     Returns: int - element size in bytes
     """
     return torch.ops.libtorch_agnostic_2_9.my_element_size.default(t)
+
+
+def my_unsqueeze(t, dim) -> Tensor:
+    """
+    Returns a new tensor with a dimension of size one inserted at the specified position.
+
+    Args:
+        t: Tensor - input tensor
+        dim: int - the index at which to insert the singleton dimension
+
+    Returns: Tensor - unsqueezed tensor
+    """
+    return torch.ops.libtorch_agnostic_2_9.my_unsqueeze.default(t, dim)
+
+
+def my_squeeze(t, dim) -> Tensor:
+    """
+    Returns a tensor with the specified dimension of size 1 removed.
+
+    Args:
+        t: Tensor - input tensor
+        dim: int - the dimension to squeeze
+
+    Returns: Tensor - squeezed tensor
+    """
+    return torch.ops.libtorch_agnostic_2_9.my_squeeze.default(t, dim)
+
+
+def my_select(t, dim, index) -> Tensor:
+    """
+    Slices the tensor along the selected dimension at the given index.
+
+    Args:
+        t: Tensor - input tensor
+        dim: int - the dimension to slice along
+        index: int - the index to select
+
+    Returns: Tensor - sliced tensor with one fewer dimension
+    """
+    return torch.ops.libtorch_agnostic_2_9.my_select.default(t, dim, index)
+
+
+def my_matmul(self, other) -> Tensor:
+    """
+    Matrix product of two tensors.
+
+    Args:
+        self: Tensor - first tensor
+        other: Tensor - second tensor
+
+    Returns: Tensor - matrix product
+    """
+    return torch.ops.libtorch_agnostic_2_9.my_matmul.default(self, other)
diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py
index 69820f86bf4..584c0d4914f 100644
--- a/test/cpp_extensions/test_libtorch_agnostic.py
+++ b/test/cpp_extensions/test_libtorch_agnostic.py
@@ -1541,6 +1541,140 @@ except RuntimeError as e:
         expected_both = t.new_zeros([4, 5], dtype=torch.int64, device="cpu")
         self.assertEqual(result_both, expected_both, exact_device=True)
 
+    def test_my_unsqueeze(self, device):
+        """Test unsqueeze op."""
+        import libtorch_agnostic_2_9 as libtorch_agnostic
+
+        t = torch.randn(3, 4, device=device)
+
+        # Test unsqueeze at dim 0
+        result = libtorch_agnostic.ops.my_unsqueeze(t, 0)
+        expected = torch.unsqueeze(t, 0)
+        self.assertEqual(result, expected)
+        self.assertEqual(result.shape, torch.Size([1, 3, 4]))
+
+        # Test unsqueeze at dim 1
+        result1 = libtorch_agnostic.ops.my_unsqueeze(t, 1)
+        expected1 = torch.unsqueeze(t, 1)
+        self.assertEqual(result1, expected1)
+        self.assertEqual(result1.shape, torch.Size([3, 1, 4]))
+
+        # Test unsqueeze at dim -1
+        result_neg = libtorch_agnostic.ops.my_unsqueeze(t, -1)
+        expected_neg = torch.unsqueeze(t, -1)
+        self.assertEqual(result_neg, expected_neg)
+        self.assertEqual(result_neg.shape, torch.Size([3, 4, 1]))
+
+    def test_my_squeeze(self, device):
+        """Test squeeze.dim op."""
+        import libtorch_agnostic_2_9 as libtorch_agnostic
+
+        t = torch.randn(3, 1, 4, device=device)
+
+        # Test squeeze at dim 1 (the dimension of size 1)
+        result = libtorch_agnostic.ops.my_squeeze(t, 1)
+        expected = torch.squeeze(t, 1)
+        self.assertEqual(result, expected)
+        self.assertEqual(result.shape, torch.Size([3, 4]))
+
+        # Test squeeze at dim 0 (not size 1, should be no-op)
+        result0 = libtorch_agnostic.ops.my_squeeze(t, 0)
+        expected0 = torch.squeeze(t, 0)
+        self.assertEqual(result0, expected0)
+        self.assertEqual(result0.shape, torch.Size([3, 1, 4]))
+
+        # Test squeeze at dim -2 (same as dim 1)
+        result_neg = libtorch_agnostic.ops.my_squeeze(t, -2)
+        expected_neg = torch.squeeze(t, -2)
+        self.assertEqual(result_neg, expected_neg)
+        self.assertEqual(result_neg.shape, torch.Size([3, 4]))
+
+    def test_my_select(self, device):
+        """Test select.int op."""
+        import libtorch_agnostic_2_9 as libtorch_agnostic
+
+        t = torch.randn(3, 4, 5, device=device)
+
+        # Test select at dim 0, index 1
+        result = libtorch_agnostic.ops.my_select(t, 0, 1)
+        expected = torch.select(t, 0, 1)
+        self.assertEqual(result, expected)
+        self.assertEqual(result.shape, torch.Size([4, 5]))
+
+        # Test select at dim 1, index 2
+        result1 = libtorch_agnostic.ops.my_select(t, 1, 2)
+        expected1 = torch.select(t, 1, 2)
+        self.assertEqual(result1, expected1)
+        self.assertEqual(result1.shape, torch.Size([3, 5]))
+
+        # Test select at dim -1, index 0
+        result_neg = libtorch_agnostic.ops.my_select(t, -1, 0)
+        expected_neg = torch.select(t, -1, 0)
+        self.assertEqual(result_neg, expected_neg)
+        self.assertEqual(result_neg.shape, torch.Size([3, 4]))
+
+    def test_my_matmul(self, device):
+        """Test matmul op."""
+        import libtorch_agnostic_2_9 as libtorch_agnostic
+
+        # Test 2D x 2D matrix multiplication
+        a = torch.randn(3, 4, device=device)
+        b = torch.randn(4, 5, device=device)
+        result = libtorch_agnostic.ops.my_matmul(a, b)
+        expected = torch.matmul(a, b)
+        self.assertEqual(result, expected)
+        self.assertEqual(result.shape, torch.Size([3, 5]))
+
+        # Test 1D x 2D (vector-matrix)
+        v = torch.randn(4, device=device)
+        m = torch.randn(4, 5, device=device)
+        result_vm = libtorch_agnostic.ops.my_matmul(v, m)
+        expected_vm = torch.matmul(v, m)
+        self.assertEqual(result_vm, expected_vm)
+
+        # Test 2D x 1D (matrix-vector)
+        m2 = torch.randn(3, 4, device=device)
+        v2 = torch.randn(4, device=device)
+        result_mv = libtorch_agnostic.ops.my_matmul(m2, v2)
+        expected_mv = torch.matmul(m2, v2)
+        self.assertEqual(result_mv, expected_mv)
+
+        # Test batched matmul
+        batch_a = torch.randn(2, 3, 4, device=device)
+        batch_b = torch.randn(2, 4, 5, device=device)
+        result_batch = libtorch_agnostic.ops.my_matmul(batch_a, batch_b)
+        expected_batch = torch.matmul(batch_a, batch_b)
+        self.assertEqual(result_batch, expected_batch)
+
+    @skipIfTorchVersionLessThan(2, 10)
+    def test_my_subtract(self, device):
+        """Test subtract.Tensor op."""
+        import libtorch_agnostic_2_10 as libtorch_agnostic
+
+        a = torch.randn(3, 4, device=device)
+        b = torch.randn(3, 4, device=device)
+
+        # Test basic subtraction (alpha=1.0)
+        result = libtorch_agnostic.ops.my_subtract(a, b)
+        expected = torch.subtract(a, b)
+        self.assertEqual(result, expected)
+
+        # Test subtraction with alpha=2.0
+        result_alpha = libtorch_agnostic.ops.my_subtract(a, b, alpha=2.0)
+        expected_alpha = torch.subtract(a, b, alpha=2.0)
+        self.assertEqual(result_alpha, expected_alpha)
+
+        # Test subtraction with alpha=0.5
+        result_half = libtorch_agnostic.ops.my_subtract(a, b, alpha=0.5)
+        expected_half = torch.subtract(a, b, alpha=0.5)
+        self.assertEqual(result_half, expected_half)
+
+        # Test subtraction with broadcasting
+        c = torch.randn(4, device=device)
+        result_broadcast = libtorch_agnostic.ops.my_subtract(a, c)
+        expected_broadcast = torch.subtract(a, c)
+        self.assertEqual(result_broadcast, expected_broadcast)
+
 instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
index 14040bea2e7..96348d6a129 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
@@ -21,6 +21,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, i
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_zeros(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_subtract_Tensor(AtenTensorHandle self, AtenTensorHandle other, double alpha, AtenTensorHandle* ret0);
 
 #ifdef __cplusplus
 } // extern "C"
diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h
index 2fbab96c4bd..246df3cfd8c 100644
--- a/torch/csrc/stable/ops.h
+++ b/torch/csrc/stable/ops.h
@@ -295,6 +295,82 @@ inline torch::stable::Tensor flatten(
   return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
 }
 
+// We expect this to be the stable version of the unsqueeze op with identical
+// semantics to the existing unsqueeze op.
+inline torch::stable::Tensor unsqueeze(
+    const torch::stable::Tensor& self,
+    int64_t dim) {
+  const auto num_args = 2;
+  std::array<StableIValue, num_args> stack{
+      torch::stable::detail::from(self), torch::stable::detail::from(dim)};
+#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
+  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
+      "aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION));
+#else
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_call_dispatcher("aten::unsqueeze", "", stack.data()));
+#endif
+  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
+}
+
+// We expect this to be the stable version of the squeeze.dim op with identical
+// semantics to the existing squeeze.dim op.
+inline torch::stable::Tensor squeeze(
+    const torch::stable::Tensor& self,
+    int64_t dim) {
+  const auto num_args = 2;
+  std::array<StableIValue, num_args> stack{
+      torch::stable::detail::from(self), torch::stable::detail::from(dim)};
+#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
+  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
+      "aten::squeeze", "dim", stack.data(), TORCH_ABI_VERSION));
+#else
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_call_dispatcher("aten::squeeze", "dim", stack.data()));
+#endif
+  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
+}
+
+// We expect this to be the stable version of the select.int op with identical
+// semantics to the existing select.int op.
+// Note: index is typed as int64_t because SymInt is not yet header-only.
+inline torch::stable::Tensor select(
+    const torch::stable::Tensor& self,
+    int64_t dim,
+    int64_t index) {
+  const auto num_args = 3;
+  std::array<StableIValue, num_args> stack{
+      torch::stable::detail::from(self),
+      torch::stable::detail::from(dim),
+      torch::stable::detail::from(index)};
+#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
+  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
+      "aten::select", "int", stack.data(), TORCH_ABI_VERSION));
+#else
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_call_dispatcher("aten::select", "int", stack.data()));
+#endif
+  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
+}
+
+// We expect this to be the stable version of the matmul op with identical
+// semantics to the existing matmul op.
+inline torch::stable::Tensor matmul(
+    const torch::stable::Tensor& self,
+    const torch::stable::Tensor& other) {
+  const auto num_args = 2;
+  std::array<StableIValue, num_args> stack{
+      torch::stable::detail::from(self), torch::stable::detail::from(other)};
+#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
+  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
+      "aten::matmul", "", stack.data(), TORCH_ABI_VERSION));
+#else
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_call_dispatcher("aten::matmul", "", stack.data()));
+#endif
+  return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
+}
+
 #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
 
 // New ops should be added here if they use a brand new shim API
@@ -578,6 +654,21 @@ inline torch::stable::Tensor& sum_out(
   return out;
 }
 
+// We expect this to be the stable version of the subtract.Tensor op.
+// Note: alpha is typed as double because the underlying C shim API
+// uses double for the Scalar parameter. We don't use torch_call_dispatcher
+// as the stableivalue conversion for Scalar is not yet available as of
+// 2.10
+inline torch::stable::Tensor subtract(
+    const torch::stable::Tensor& self,
+    const torch::stable::Tensor& other,
+    double alpha = 1.0) {
+  AtenTensorHandle ret0;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_aten_subtract_Tensor(self.get(), other.get(), alpha, &ret0));
+  return torch::stable::Tensor(ret0);
+}
+
 // We expect this to be the stable version of the full.default op.
 // Note: fill_value is typed as double because the underlying C shim API
 // uses double for the Scalar parameter. We don't use torch_call_dispatcher
diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py
index 46d3327f941..f78cc85e226 100644
--- a/torchgen/aoti/fallback_ops.py
+++ b/torchgen/aoti/fallback_ops.py
@@ -190,4 +190,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
     "aten.new_empty.default": {},
     "aten.new_zeros.default": {},
     "aten.full.default": {},
+    "aten.subtract.Tensor": {},
 }