mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[Metal] Extend typecasted op support to complex dtypes (#152504)
First of all, by extending `c10::metal::cast_to` to work correctly with complex dtypes, by introducing two more specializations: one that casts complex to scalar, and another that casts scalar to complex (as default metal typecast will turn `float x` into `float2(x, x)`) Add ComplexHalf and ComplexFloat enum values to `c10::metal::ScalarTypes` and handle them in `val_at_offs(ptr, offs, type)` Pull Request resolved: https://github.com/pytorch/pytorch/pull/152504 Approved by: https://github.com/dcci ghstack dependencies: #152443, #152466, #152479
This commit is contained in:
committed by
PyTorch MergeBot
parent
4df97a8839
commit
a2c553cac6
@@ -16,6 +16,8 @@
|
||||
_(Long, 4) \
|
||||
_(Half, 5) \
|
||||
_(Float, 6) \
|
||||
_(ComplexHalf, 8) \
|
||||
_(ComplexFloat, 9) \
|
||||
_(Bool, 11) \
|
||||
_(BFloat16, 15)
|
||||
#else
|
||||
@@ -27,6 +29,8 @@
|
||||
_(Long, 4) \
|
||||
_(Half, 5) \
|
||||
_(Float, 6) \
|
||||
_(ComplexHalf, 8) \
|
||||
_(ComplexFloat, 9) \
|
||||
_(Bool, 11)
|
||||
#endif
|
||||
|
||||
|
||||
@@ -115,26 +115,31 @@ template <typename T>
|
||||
inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
|
||||
switch (type) {
|
||||
case ScalarType::Bool:
|
||||
return val_at_offs<bool>(ptr, offs);
|
||||
return cast_to<T>(val_at_offs<bool>(ptr, offs));
|
||||
case ScalarType::Byte:
|
||||
return val_at_offs<uchar>(ptr, offs);
|
||||
return cast_to<T>(val_at_offs<uchar>(ptr, offs));
|
||||
case ScalarType::Char:
|
||||
return val_at_offs<char>(ptr, offs);
|
||||
return cast_to<T>(val_at_offs<char>(ptr, offs));
|
||||
case ScalarType::Short:
|
||||
return val_at_offs<short>(ptr, offs);
|
||||
return cast_to<T>(val_at_offs<short>(ptr, offs));
|
||||
case ScalarType::Int:
|
||||
return val_at_offs<int>(ptr, offs);
|
||||
return cast_to<T>(val_at_offs<int>(ptr, offs));
|
||||
case ScalarType::Long:
|
||||
return val_at_offs<long>(ptr, offs);
|
||||
return cast_to<T>(val_at_offs<long>(ptr, offs));
|
||||
// Floats
|
||||
case ScalarType::Float:
|
||||
return static_cast<T>(val_at_offs<float>(ptr, offs));
|
||||
return cast_to<T>(val_at_offs<float>(ptr, offs));
|
||||
case ScalarType::Half:
|
||||
return static_cast<T>(val_at_offs<half>(ptr, offs));
|
||||
return cast_to<T>(val_at_offs<half>(ptr, offs));
|
||||
#if __METAL_VERSION__ >= 310
|
||||
case ScalarType::BFloat16:
|
||||
return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
|
||||
#endif
|
||||
// Complex
|
||||
case ScalarType::ComplexHalf:
|
||||
return cast_to<T>(val_at_offs<half2>(ptr, offs));
|
||||
case ScalarType::ComplexFloat:
|
||||
return cast_to<T>(val_at_offs<float2>(ptr, offs));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -148,20 +148,41 @@ template <typename T>
|
||||
constexpr constant bool is_scalar_integral_v =
|
||||
::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
|
||||
|
||||
template <typename T, typename U>
|
||||
inline ::metal::enable_if_t<::metal::is_same_v<U, T>, T> cast_to(const U from) {
|
||||
// cast_to primitives
|
||||
// - No-op if types as the same
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
::metal::enable_if_t<::metal::is_same_v<U, T>, bool> = true>
|
||||
inline T cast_to(const U from) {
|
||||
return from;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline ::metal::enable_if_t<is_complex_v<T>, T> cast_to(const U from) {
|
||||
return T(float(from), 0.0);
|
||||
// - Simple cast between scalar and complex dtypes
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
::metal::enable_if_t<
|
||||
!::metal::is_same_v<U, T> && (is_complex_v<T> == is_complex_v<U>),
|
||||
bool> = true>
|
||||
inline T cast_to(const U from) {
|
||||
return static_cast<T>(from);
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline ::metal::enable_if_t<!::metal::is_same_v<U, T> && !is_complex_v<T>, T>
|
||||
cast_to(const U from) {
|
||||
return static_cast<T>(from);
|
||||
// - Scalar to complex
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
::metal::enable_if_t<is_complex_v<T> && !is_complex_v<U>, bool> = true>
|
||||
inline T cast_to(const U from) {
|
||||
return T(float(from), 0.0);
|
||||
}
|
||||
// - Complex to scalar (should not really be used, but exists for compliteness)
|
||||
template <
|
||||
typename T,
|
||||
typename U,
|
||||
::metal::enable_if_t<!is_complex_v<T> && is_complex_v<U>, bool> = true>
|
||||
inline T cast_to(const U from) {
|
||||
return static_cast<T>(from.x);
|
||||
}
|
||||
|
||||
// Generalizable math operators (used for both scalar and complex)
|
||||
|
||||
Reference in New Issue
Block a user