From 9bfdf575724942f9ebcd6a54f3590357e5f43b8a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 29 Apr 2025 20:47:22 -0700 Subject: [PATCH] [MPS][BE] Introduce `c10::metal::mul` (#152466) Which multiplies two arguments for either scalar or complex data types. This allows one to get rid of a bunch of complex specializations in BinaryOps. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152466 Approved by: https://github.com/dcci ghstack dependencies: #152443 --- .../native/mps/kernels/BinaryKernel.metal | 55 ++++--------------- .../ATen/native/mps/operations/BinaryOps.mm | 6 +- c10/metal/utils.h | 20 +++++++ 3 files changed, 31 insertions(+), 50 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 4c585d122b1..64e44b370d3 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -21,21 +21,21 @@ struct sub_functor { struct add_alpha_functor { template <typename T> inline T operator()(const T a, const T b, const T alpha) { - return static_cast<T>(a + (alpha * b)); + return static_cast<T>(a + c10::metal::mul(alpha, b)); } }; struct sub_alpha_functor { template <typename T> inline T operator()(const T a, const T b, const T alpha) { - return static_cast<T>(a - (alpha * b)); + return static_cast<T>(a - c10::metal::mul(alpha, b)); } }; struct lerp_alpha_functor { template <typename T> inline T operator()(const T a, const T b, const T alpha) { - return static_cast<T>(a + (alpha * (b - a))); + return static_cast<T>(a + c10::metal::mul(alpha, b - a)); } }; @@ -183,42 +183,7 @@ struct make_complex_functor { struct complex_mul_functor { template <typename T> inline T operator()(const T a, const T b) { - return T(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); - } -}; - -struct complex_add_alpha_functor { - template <typename T> - inline T operator()(const T a, const T b, const T alpha) { - return T( - a.x + (alpha.x * b.x - alpha.y * b.y), - a.y + (alpha.x * b.y + alpha.y * b.x)); - } -}; - 
-struct complex_add_functor { - template <typename T> - inline T operator()(const T a, const T b) { - return T(a.x + b.x, a.y + b.y); - } -}; - -struct complex_sub_alpha_functor { - template <typename T> - inline T operator()(const T a, const T b, const T alpha) { - return T( - a.x - (alpha.x * b.x - alpha.y * b.y), - a.y - (alpha.x * b.y + alpha.y * b.x)); - } -}; - -struct complex_lerp_alpha_functor { - template <typename T> - inline T operator()(const T a, const T b, const T alpha) { - auto intr = T(b.x - a.x, b.y - a.y); - return T( - a.x + (alpha.x * intr.x - intr.y * intr.y), - a.y + (alpha.x * intr.y + alpha.y * intr.x)); + return c10::metal::mul(a, b); } }; @@ -324,9 +289,9 @@ REGISTER_BINARY_OP(add, float2, float2); REGISTER_BINARY_OP(add, half2, half2); REGISTER_BINARY_OP(sub, float2, float2); REGISTER_BINARY_OP(sub, half2, half2); -REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2); -REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2); -REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2); -REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, half2, half2); -REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, float2, float2); -REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, half2, half2); +REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2); +REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2); +REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2); +REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2); +REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2); diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index 62984bc5e51..aa5b8f262c1 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -278,11 +278,7 @@ static void add_sub_lerp_template(const Tensor& self, if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) { if (alpha_has_value) { 
at::native::alpha_check(commonDtype, alpha); - mps::binary_op_kernel((self_complex || other_complex) ? "complex_" + op_name : op_name, - self, - other, - output, - getMPSScalar(alpha, commonDtype)); + mps::binary_op_kernel(op_name, self, other, output, getMPSScalar(alpha, commonDtype)); } else { mps::binary_op_kernel(op_name, self, other, output); } diff --git a/c10/metal/utils.h b/c10/metal/utils.h index a763cb3706d..198cac71416 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -164,5 +164,25 @@ cast_to(const U from) { return static_cast<T>(from); } +// Generalizable math operators (used for both scalar and complex) +template <typename U, typename V> +using common_dtype = decltype(U(0) + V(0)); + +template < + typename T, + typename U, + ::metal::enable_if_t<!is_complex_v<T>, bool> = true> +inline common_dtype<T, U> mul(const T x, const U y) { + return x * y; +} + +template < + typename T, + typename U, + ::metal::enable_if_t<is_complex_v<T> && is_complex_v<U>, bool> = true> +inline common_dtype<T, U> mul(const T x, const U y) { + return T(x.x * y.x - x.y * y.y, x.x * y.y + x.y * y.x); +} + } // namespace metal } // namespace c10