[MPS] Migrate clamp.Tensor_out to Metal (#169407)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/169407
Approved by: https://github.com/malfet
This commit is contained in:
Kurt Mohler
2025-12-03 11:56:36 -06:00
committed by PyTorch MergeBot
parent abfa1a6d65
commit 8c73bbbb02
8 changed files with 329 additions and 6 deletions

View File

@@ -475,5 +475,177 @@ kernel void binary_alpha_dense_cast(
constant DTYPEA& alpha, \
constant uint4& sizes_types, \
uint tid)
// Ternary elementwise ops kernels
// Right now there are 4 flavors available:
// - ternary_dense where both input, other1, other2, and output are dense and
// share the same type
// - ternary_strided when all inputs are of the same types, but some elements
// are strided
// - ternary_dense_cast - inputs are dense, but of different dtypes
// - ternary_strided_cast - inputs or output are strided and of different dtypes
// Note about accuracy (for more info see
// https://github.com/pytorch/pytorch/issues/152736) Sometimes when kernel is
// invoked to produce `half` output, but one of the arguments is float arguments
// should be upcast to float, rather than downcast to half At the moment this is
// expressed with `om_t` optional argument (which stands for opmath_type) which
// is identical to output type but could be something else
template <typename T, typename F, typename om_t = T>
kernel void ternary_strided(
device void* output [[buffer(0)]],
constant void* input [[buffer(1)]],
constant void* other1 [[buffer(2)]],
constant void* other2 [[buffer(3)]],
constant long* sizes [[buffer(4)]],
constant long* output_strides [[buffer(5)]],
constant long* input_strides [[buffer(6)]],
constant long* other1_strides [[buffer(7)]],
constant long* other2_strides [[buffer(8)]],
constant uint& ndim [[buffer(9)]],
constant uint4& types [[buffer(10)]],
uint index [[thread_position_in_grid]]) {
F f;
using res_t = result_of<F, T, T, T>;
int pos[max_ndim];
pos_from_thread_index(int(index), pos, sizes, ndim);
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
const auto other1_offs = offset_from_coord(pos, other1_strides, ndim);
const auto other2_offs = offset_from_coord(pos, other2_strides, ndim);
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
const auto a = val_at_offs<T>(input, input_offs);
const auto b = val_at_offs<T>(other1, other1_offs);
const auto c = val_at_offs<T>(other2, other2_offs);
ref_at_offs<res_t>(output, output_offs) =
static_cast<res_t>(f(om_t(a), om_t(b), om_t(c)));
}
template <typename T, typename F, typename om_t = opmath_t<T>>
kernel void ternary_strided_cast(
device void* output [[buffer(0)]],
constant void* input [[buffer(1)]],
constant void* other1 [[buffer(2)]],
constant void* other2 [[buffer(3)]],
constant long* sizes [[buffer(4)]],
constant long* output_strides [[buffer(5)]],
constant long* input_strides [[buffer(6)]],
constant long* other1_strides [[buffer(7)]],
constant long* other2_strides [[buffer(8)]],
constant uint& ndim [[buffer(9)]],
constant uint4& types [[buffer(10)]],
uint index [[thread_position_in_grid]]) {
F f;
using res_t = result_of<F, T, T, T>;
int pos[max_ndim];
pos_from_thread_index(int(index), pos, sizes, ndim);
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
const auto other1_offs = offset_from_coord(pos, other1_strides, ndim);
const auto other2_offs = offset_from_coord(pos, other2_strides, ndim);
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
const auto a =
val_at_offs<om_t>(input, input_offs, static_cast<ScalarType>(types.x));
const auto b =
val_at_offs<om_t>(other1, other1_offs, static_cast<ScalarType>(types.y));
const auto c =
val_at_offs<om_t>(other2, other2_offs, static_cast<ScalarType>(types.z));
ref_at_offs<res_t>(output, output_offs) = static_cast<res_t>(f(a, b, c));
}
template <typename T, typename F, typename om_t = opmath_t<T>>
kernel void ternary_dense(
device result_of<F, T, T, T>* out [[buffer(0)]],
constant T* input [[buffer(1)]],
constant T* other1 [[buffer(2)]],
constant T* other2 [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
F f;
using res_t = result_of<F, T, T, T>;
out[tid] = static_cast<res_t>(
f(om_t(input[tid]), om_t(other1[tid]), om_t(other2[tid])));
}
template <typename T, typename F, typename om_t = T>
kernel void ternary_dense_cast(
device result_of<F, T, T, T>* out [[buffer(0)]],
constant void* input [[buffer(1)]],
constant void* other1 [[buffer(2)]],
constant void* other2 [[buffer(3)]],
constant uint3& sizes [[buffer(4)]],
constant uint3& types [[buffer(5)]],
uint tid [[thread_position_in_grid]]) {
F f;
using res_t = result_of<F, T, T, T>;
const auto a =
val_at_offs<om_t>(input, tid * sizes.x, static_cast<ScalarType>(types.x));
const auto b = val_at_offs<om_t>(
other1, tid * sizes.y, static_cast<ScalarType>(types.y));
const auto c = val_at_offs<om_t>(
other2, tid * sizes.z, static_cast<ScalarType>(types.z));
out[tid] = static_cast<res_t>(f(a, b, c));
}
#define REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \
static_assert( \
::metal::is_same_v< \
DTYPEO, \
::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI>>, \
"Output dtype mismatch for ternary op " #NAME " and input " #DTYPEI); \
template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
c10::metal::ternary_strided<DTYPEI, NAME##_functor, OMT>( \
device void* out, \
constant void* input, \
constant void* other1, \
constant void* other2, \
constant long* sizes, \
constant long* output_strides, \
constant long* input_strides, \
constant long* other1_strides, \
constant long* other2_strides, \
constant uint& ndim, \
constant uint4& types, \
uint tid); \
template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \
metal::ternary_strided_cast<DTYPEI, NAME##_functor, OMT>( \
device void* out, \
constant void* input, \
constant void* other1, \
constant void* other2, \
constant long* sizes, \
constant long* output_strides, \
constant long* input_strides, \
constant long* other1_strides, \
constant long* other2_strides, \
constant uint& ndim, \
constant uint4& types, \
uint tid); \
template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
c10::metal::ternary_dense<DTYPEI, NAME##_functor, OMT>( \
device ::c10::metal:: \
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
out_, \
constant DTYPEI * input_, \
constant DTYPEI * other1_, \
constant DTYPEI * other2_, \
uint tid); \
template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \
metal::ternary_dense_cast<DTYPEI, NAME##_functor, OMT>( \
device ::c10::metal:: \
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
out_, \
constant void* input, \
constant void* other1, \
constant void* other2, \
constant uint3& sizes, \
constant uint3& types, \
uint tid)
// OpMath ternary Op promotes inputs to higher precision type before Functor
// call
#define REGISTER_OPMATH_TERNARY_OP(NAME, DTYPEI, DTYPEO) \
REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, ::c10::metal::opmath_t<DTYPEI>)
#define REGISTER_TERNARY_OP(NAME, DTYPEI, DTYPEO) \
REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI)
} // namespace metal
} // namespace c10