[MPS][BE] Introduce c10::metal::mul (#152466)

Which multiplies two arguments for either scalar or complex data types

This allows one to get rid of bunch of complex specialization in BinaryOps
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152466
Approved by: https://github.com/dcci
ghstack dependencies: #152443
This commit is contained in:
Nikita Shulga
2025-04-29 20:47:22 -07:00
committed by PyTorch MergeBot
parent ee2d104c05
commit 9bfdf57572
3 changed files with 31 additions and 50 deletions

View File

@@ -21,21 +21,21 @@ struct sub_functor {
struct add_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return static_cast<T>(a + (alpha * b));
return static_cast<T>(a + c10::metal::mul(alpha, b));
}
};
struct sub_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return static_cast<T>(a - (alpha * b));
return static_cast<T>(a - c10::metal::mul(alpha, b));
}
};
struct lerp_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return static_cast<T>(a + (alpha * (b - a)));
return static_cast<T>(a + c10::metal::mul(alpha, b - a));
}
};
@@ -183,42 +183,7 @@ struct make_complex_functor {
struct complex_mul_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return T(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}
};
struct complex_add_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return T(
a.x + (alpha.x * b.x - alpha.y * b.y),
a.y + (alpha.x * b.y + alpha.y * b.x));
}
};
struct complex_add_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return T(a.x + b.x, a.y + b.y);
}
};
struct complex_sub_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return T(
a.x - (alpha.x * b.x - alpha.y * b.y),
a.y - (alpha.x * b.y + alpha.y * b.x));
}
};
struct complex_lerp_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
auto intr = T(b.x - a.x, b.y - a.y);
return T(
a.x + (alpha.x * intr.x - intr.y * intr.y),
a.y + (alpha.x * intr.y + alpha.y * intr.x));
return c10::metal::mul(a, b);
}
};
@@ -324,9 +289,9 @@ REGISTER_BINARY_OP(add, float2, float2);
REGISTER_BINARY_OP(add, half2, half2);
REGISTER_BINARY_OP(sub, float2, float2);
REGISTER_BINARY_OP(sub, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2);

View File

@@ -278,11 +278,7 @@ static void add_sub_lerp_template(const Tensor& self,
if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) {
if (alpha_has_value) {
at::native::alpha_check(commonDtype, alpha);
mps::binary_op_kernel((self_complex || other_complex) ? "complex_" + op_name : op_name,
self,
other,
output,
getMPSScalar(alpha, commonDtype));
mps::binary_op_kernel(op_name, self, other, output, getMPSScalar(alpha, commonDtype));
} else {
mps::binary_op_kernel(op_name, self, other, output);
}

View File

@@ -164,5 +164,25 @@ cast_to(const U from) {
return static_cast<T>(from);
}
// Generalizable math operators (used for both scalar and complex)
template <typename U, typename V>
using common_dtype = decltype(U(0) + V(0));
template <
typename T,
typename U,
::metal::enable_if_t<!is_complex_v<T>, bool> = true>
inline common_dtype<T, U> mul(const T x, const U y) {
return x * y;
}
template <
typename T,
typename U,
::metal::enable_if_t<is_complex_v<T> && is_complex_v<U>, bool> = true>
inline common_dtype<T, U> mul(const T x, const U y) {
return T(x.x * y.x - x.y * y.y, x.x * y.y + x.y * y.x);
}
} // namespace metal
} // namespace c10