mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[MPS] Migrate clamp.Tensor_out to Metal (#169407)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/169407 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
abfa1a6d65
commit
8c73bbbb02
@@ -475,5 +475,177 @@ kernel void binary_alpha_dense_cast(
|
||||
constant DTYPEA& alpha, \
|
||||
constant uint4& sizes_types, \
|
||||
uint tid)
|
||||
|
||||
// Ternary elementwise ops kernels
|
||||
// Right now there are 4 flavors available:
|
||||
// - ternary_dense where both input, other1, other2, and output are dense and
|
||||
// share the same type
|
||||
// - ternary_strided when all inputs are of the same types, but some elements
|
||||
// are strided
|
||||
// - ternary_dense_cast - inputs are dense, but of different dtypes
|
||||
// - ternary_strided_cast - inputs or output are strided and of different dtypes
|
||||
// Note about accuracy (for more info see
|
||||
// https://github.com/pytorch/pytorch/issues/152736) Sometimes when kernel is
|
||||
// invoked to produce `half` output, but one of the arguments is float arguments
|
||||
// should be upcast to float, rather than downcast to half At the moment this is
|
||||
// expressed with `om_t` optional argument (which stands for opmath_type) which
|
||||
// is identical to output type but could be something else
|
||||
|
||||
template <typename T, typename F, typename om_t = T>
|
||||
kernel void ternary_strided(
|
||||
device void* output [[buffer(0)]],
|
||||
constant void* input [[buffer(1)]],
|
||||
constant void* other1 [[buffer(2)]],
|
||||
constant void* other2 [[buffer(3)]],
|
||||
constant long* sizes [[buffer(4)]],
|
||||
constant long* output_strides [[buffer(5)]],
|
||||
constant long* input_strides [[buffer(6)]],
|
||||
constant long* other1_strides [[buffer(7)]],
|
||||
constant long* other2_strides [[buffer(8)]],
|
||||
constant uint& ndim [[buffer(9)]],
|
||||
constant uint4& types [[buffer(10)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
using res_t = result_of<F, T, T, T>;
|
||||
int pos[max_ndim];
|
||||
pos_from_thread_index(int(index), pos, sizes, ndim);
|
||||
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
|
||||
const auto other1_offs = offset_from_coord(pos, other1_strides, ndim);
|
||||
const auto other2_offs = offset_from_coord(pos, other2_strides, ndim);
|
||||
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
|
||||
const auto a = val_at_offs<T>(input, input_offs);
|
||||
const auto b = val_at_offs<T>(other1, other1_offs);
|
||||
const auto c = val_at_offs<T>(other2, other2_offs);
|
||||
ref_at_offs<res_t>(output, output_offs) =
|
||||
static_cast<res_t>(f(om_t(a), om_t(b), om_t(c)));
|
||||
}
|
||||
|
||||
template <typename T, typename F, typename om_t = opmath_t<T>>
|
||||
kernel void ternary_strided_cast(
|
||||
device void* output [[buffer(0)]],
|
||||
constant void* input [[buffer(1)]],
|
||||
constant void* other1 [[buffer(2)]],
|
||||
constant void* other2 [[buffer(3)]],
|
||||
constant long* sizes [[buffer(4)]],
|
||||
constant long* output_strides [[buffer(5)]],
|
||||
constant long* input_strides [[buffer(6)]],
|
||||
constant long* other1_strides [[buffer(7)]],
|
||||
constant long* other2_strides [[buffer(8)]],
|
||||
constant uint& ndim [[buffer(9)]],
|
||||
constant uint4& types [[buffer(10)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
using res_t = result_of<F, T, T, T>;
|
||||
int pos[max_ndim];
|
||||
pos_from_thread_index(int(index), pos, sizes, ndim);
|
||||
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
|
||||
const auto other1_offs = offset_from_coord(pos, other1_strides, ndim);
|
||||
const auto other2_offs = offset_from_coord(pos, other2_strides, ndim);
|
||||
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
|
||||
const auto a =
|
||||
val_at_offs<om_t>(input, input_offs, static_cast<ScalarType>(types.x));
|
||||
const auto b =
|
||||
val_at_offs<om_t>(other1, other1_offs, static_cast<ScalarType>(types.y));
|
||||
const auto c =
|
||||
val_at_offs<om_t>(other2, other2_offs, static_cast<ScalarType>(types.z));
|
||||
ref_at_offs<res_t>(output, output_offs) = static_cast<res_t>(f(a, b, c));
|
||||
}
|
||||
|
||||
template <typename T, typename F, typename om_t = opmath_t<T>>
|
||||
kernel void ternary_dense(
|
||||
device result_of<F, T, T, T>* out [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
constant T* other1 [[buffer(2)]],
|
||||
constant T* other2 [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
using res_t = result_of<F, T, T, T>;
|
||||
out[tid] = static_cast<res_t>(
|
||||
f(om_t(input[tid]), om_t(other1[tid]), om_t(other2[tid])));
|
||||
}
|
||||
|
||||
template <typename T, typename F, typename om_t = T>
|
||||
kernel void ternary_dense_cast(
|
||||
device result_of<F, T, T, T>* out [[buffer(0)]],
|
||||
constant void* input [[buffer(1)]],
|
||||
constant void* other1 [[buffer(2)]],
|
||||
constant void* other2 [[buffer(3)]],
|
||||
constant uint3& sizes [[buffer(4)]],
|
||||
constant uint3& types [[buffer(5)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
using res_t = result_of<F, T, T, T>;
|
||||
const auto a =
|
||||
val_at_offs<om_t>(input, tid * sizes.x, static_cast<ScalarType>(types.x));
|
||||
const auto b = val_at_offs<om_t>(
|
||||
other1, tid * sizes.y, static_cast<ScalarType>(types.y));
|
||||
const auto c = val_at_offs<om_t>(
|
||||
other2, tid * sizes.z, static_cast<ScalarType>(types.z));
|
||||
out[tid] = static_cast<res_t>(f(a, b, c));
|
||||
}
|
||||
|
||||
#define REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \
|
||||
static_assert( \
|
||||
::metal::is_same_v< \
|
||||
DTYPEO, \
|
||||
::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI>>, \
|
||||
"Output dtype mismatch for ternary op " #NAME " and input " #DTYPEI); \
|
||||
template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
|
||||
c10::metal::ternary_strided<DTYPEI, NAME##_functor, OMT>( \
|
||||
device void* out, \
|
||||
constant void* input, \
|
||||
constant void* other1, \
|
||||
constant void* other2, \
|
||||
constant long* sizes, \
|
||||
constant long* output_strides, \
|
||||
constant long* input_strides, \
|
||||
constant long* other1_strides, \
|
||||
constant long* other2_strides, \
|
||||
constant uint& ndim, \
|
||||
constant uint4& types, \
|
||||
uint tid); \
|
||||
template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \
|
||||
metal::ternary_strided_cast<DTYPEI, NAME##_functor, OMT>( \
|
||||
device void* out, \
|
||||
constant void* input, \
|
||||
constant void* other1, \
|
||||
constant void* other2, \
|
||||
constant long* sizes, \
|
||||
constant long* output_strides, \
|
||||
constant long* input_strides, \
|
||||
constant long* other1_strides, \
|
||||
constant long* other2_strides, \
|
||||
constant uint& ndim, \
|
||||
constant uint4& types, \
|
||||
uint tid); \
|
||||
template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
|
||||
c10::metal::ternary_dense<DTYPEI, NAME##_functor, OMT>( \
|
||||
device ::c10::metal:: \
|
||||
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
|
||||
out_, \
|
||||
constant DTYPEI * input_, \
|
||||
constant DTYPEI * other1_, \
|
||||
constant DTYPEI * other2_, \
|
||||
uint tid); \
|
||||
template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \
|
||||
metal::ternary_dense_cast<DTYPEI, NAME##_functor, OMT>( \
|
||||
device ::c10::metal:: \
|
||||
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
|
||||
out_, \
|
||||
constant void* input, \
|
||||
constant void* other1, \
|
||||
constant void* other2, \
|
||||
constant uint3& sizes, \
|
||||
constant uint3& types, \
|
||||
uint tid)
|
||||
|
||||
// OpMath ternary Op promotes inputs to higher precision type before Functor
|
||||
// call
|
||||
#define REGISTER_OPMATH_TERNARY_OP(NAME, DTYPEI, DTYPEO) \
|
||||
REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, ::c10::metal::opmath_t<DTYPEI>)
|
||||
|
||||
#define REGISTER_TERNARY_OP(NAME, DTYPEI, DTYPEO) \
|
||||
REGISTER_TERNARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI)
|
||||
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
||||
Reference in New Issue
Block a user