[Metal] Extend typecasted op support to complex dtypes (#152504)

First, extend `c10::metal::cast_to` to work correctly with complex dtypes by introducing two more overloads: one that casts complex to scalar, and another that casts scalar to complex (needed because the default Metal typecast turns `float x` into `float2(x, x)`).

Second, add ComplexHalf and ComplexFloat enum values to `c10::metal::ScalarType` and handle them in `val_at_offs(ptr, offs, type)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152504
Approved by: https://github.com/dcci
ghstack dependencies: #152443, #152466, #152479
This commit is contained in:
Nikita Shulga
2025-04-29 22:28:25 -07:00
committed by PyTorch MergeBot
parent 4df97a8839
commit a2c553cac6
3 changed files with 48 additions and 18 deletions

View File

@@ -16,6 +16,8 @@
_(Long, 4) \
_(Half, 5) \
_(Float, 6) \
_(ComplexHalf, 8) \
_(ComplexFloat, 9) \
_(Bool, 11) \
_(BFloat16, 15)
#else
@@ -27,6 +29,8 @@
_(Long, 4) \
_(Half, 5) \
_(Float, 6) \
_(ComplexHalf, 8) \
_(ComplexFloat, 9) \
_(Bool, 11)
#endif

View File

@@ -115,26 +115,31 @@ template <typename T>
inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
switch (type) {
case ScalarType::Bool:
return val_at_offs<bool>(ptr, offs);
return cast_to<T>(val_at_offs<bool>(ptr, offs));
case ScalarType::Byte:
return val_at_offs<uchar>(ptr, offs);
return cast_to<T>(val_at_offs<uchar>(ptr, offs));
case ScalarType::Char:
return val_at_offs<char>(ptr, offs);
return cast_to<T>(val_at_offs<char>(ptr, offs));
case ScalarType::Short:
return val_at_offs<short>(ptr, offs);
return cast_to<T>(val_at_offs<short>(ptr, offs));
case ScalarType::Int:
return val_at_offs<int>(ptr, offs);
return cast_to<T>(val_at_offs<int>(ptr, offs));
case ScalarType::Long:
return val_at_offs<long>(ptr, offs);
return cast_to<T>(val_at_offs<long>(ptr, offs));
// Floats
case ScalarType::Float:
return static_cast<T>(val_at_offs<float>(ptr, offs));
return cast_to<T>(val_at_offs<float>(ptr, offs));
case ScalarType::Half:
return static_cast<T>(val_at_offs<half>(ptr, offs));
return cast_to<T>(val_at_offs<half>(ptr, offs));
#if __METAL_VERSION__ >= 310
case ScalarType::BFloat16:
return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
#endif
// Complex
case ScalarType::ComplexHalf:
return cast_to<T>(val_at_offs<half2>(ptr, offs));
case ScalarType::ComplexFloat:
return cast_to<T>(val_at_offs<float2>(ptr, offs));
}
}

View File

@@ -148,20 +148,41 @@ template <typename T>
constexpr constant bool is_scalar_integral_v =
::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
template <typename T, typename U>
inline ::metal::enable_if_t<::metal::is_same_v<U, T>, T> cast_to(const U from) {
// cast_to primitives
// - No-op if types are the same
template <
typename T,
typename U,
::metal::enable_if_t<::metal::is_same_v<U, T>, bool> = true>
inline T cast_to(const U from) {
return from;
}
template <typename T, typename U>
inline ::metal::enable_if_t<is_complex_v<T>, T> cast_to(const U from) {
return T(float(from), 0.0);
// - Simple cast when both types are scalar or both types are complex
template <
typename T,
typename U,
::metal::enable_if_t<
!::metal::is_same_v<U, T> && (is_complex_v<T> == is_complex_v<U>),
bool> = true>
inline T cast_to(const U from) {
return static_cast<T>(from);
}
template <typename T, typename U>
inline ::metal::enable_if_t<!::metal::is_same_v<U, T> && !is_complex_v<T>, T>
cast_to(const U from) {
return static_cast<T>(from);
// - Scalar to complex
template <
typename T,
typename U,
::metal::enable_if_t<is_complex_v<T> && !is_complex_v<U>, bool> = true>
inline T cast_to(const U from) {
return T(float(from), 0.0);
}
// - Complex to scalar (should not really be used, but exists for completeness)
template <
typename T,
typename U,
::metal::enable_if_t<!is_complex_v<T> && is_complex_v<U>, bool> = true>
inline T cast_to(const U from) {
return static_cast<T>(from.x);
}
// Generalizable math operators (used for both scalar and complex)