mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[MPS][BE] Introduce c10::metal::mul (#152466)
Which multiplies two arguments for either scalar or complex data types This allows one to get rid of bunch of complex specialization in BinaryOps Pull Request resolved: https://github.com/pytorch/pytorch/pull/152466 Approved by: https://github.com/dcci ghstack dependencies: #152443
This commit is contained in:
committed by
PyTorch MergeBot
parent
ee2d104c05
commit
9bfdf57572
@@ -21,21 +21,21 @@ struct sub_functor {
|
||||
struct add_alpha_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b, const T alpha) {
|
||||
return static_cast<T>(a + (alpha * b));
|
||||
return static_cast<T>(a + c10::metal::mul(alpha, b));
|
||||
}
|
||||
};
|
||||
|
||||
struct sub_alpha_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b, const T alpha) {
|
||||
return static_cast<T>(a - (alpha * b));
|
||||
return static_cast<T>(a - c10::metal::mul(alpha, b));
|
||||
}
|
||||
};
|
||||
|
||||
struct lerp_alpha_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b, const T alpha) {
|
||||
return static_cast<T>(a + (alpha * (b - a)));
|
||||
return static_cast<T>(a + c10::metal::mul(alpha, b - a));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -183,42 +183,7 @@ struct make_complex_functor {
|
||||
struct complex_mul_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
return T(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
|
||||
}
|
||||
};
|
||||
|
||||
struct complex_add_alpha_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b, const T alpha) {
|
||||
return T(
|
||||
a.x + (alpha.x * b.x - alpha.y * b.y),
|
||||
a.y + (alpha.x * b.y + alpha.y * b.x));
|
||||
}
|
||||
};
|
||||
|
||||
struct complex_add_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
return T(a.x + b.x, a.y + b.y);
|
||||
}
|
||||
};
|
||||
|
||||
struct complex_sub_alpha_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b, const T alpha) {
|
||||
return T(
|
||||
a.x - (alpha.x * b.x - alpha.y * b.y),
|
||||
a.y - (alpha.x * b.y + alpha.y * b.x));
|
||||
}
|
||||
};
|
||||
|
||||
struct complex_lerp_alpha_functor {
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b, const T alpha) {
|
||||
auto intr = T(b.x - a.x, b.y - a.y);
|
||||
return T(
|
||||
a.x + (alpha.x * intr.x - intr.y * intr.y),
|
||||
a.y + (alpha.x * intr.y + alpha.y * intr.x));
|
||||
return c10::metal::mul(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -324,9 +289,9 @@ REGISTER_BINARY_OP(add, float2, float2);
|
||||
REGISTER_BINARY_OP(add, half2, half2);
|
||||
REGISTER_BINARY_OP(sub, float2, float2);
|
||||
REGISTER_BINARY_OP(sub, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
|
||||
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);
|
||||
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, float2, float2);
|
||||
REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2);
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2);
|
||||
REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2);
|
||||
|
||||
@@ -278,11 +278,7 @@ static void add_sub_lerp_template(const Tensor& self,
|
||||
if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) {
|
||||
if (alpha_has_value) {
|
||||
at::native::alpha_check(commonDtype, alpha);
|
||||
mps::binary_op_kernel((self_complex || other_complex) ? "complex_" + op_name : op_name,
|
||||
self,
|
||||
other,
|
||||
output,
|
||||
getMPSScalar(alpha, commonDtype));
|
||||
mps::binary_op_kernel(op_name, self, other, output, getMPSScalar(alpha, commonDtype));
|
||||
} else {
|
||||
mps::binary_op_kernel(op_name, self, other, output);
|
||||
}
|
||||
|
||||
@@ -164,5 +164,25 @@ cast_to(const U from) {
|
||||
return static_cast<T>(from);
|
||||
}
|
||||
|
||||
// Generalizable math operators (used for both scalar and complex)
|
||||
template <typename U, typename V>
|
||||
using common_dtype = decltype(U(0) + V(0));
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
::metal::enable_if_t<!is_complex_v<T>, bool> = true>
|
||||
inline common_dtype<T, U> mul(const T x, const U y) {
|
||||
return x * y;
|
||||
}
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
::metal::enable_if_t<is_complex_v<T> && is_complex_v<U>, bool> = true>
|
||||
inline common_dtype<T, U> mul(const T x, const U y) {
|
||||
return T(x.x * y.x - x.y * y.y, x.x * y.y + x.y * y.x);
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
||||
Reference in New Issue
Block a user