mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[MPS] Implement hardshrink metal kernel (#155304)
Implements the forward and backward hardshrink operators as Metal kernels. In order to support the lambda parameter, we extend the `exec_unary_kernel` and `exec_binary_kernel` methods. Now they take an optional Scalar and an optional ScalarType argument. When the optional ScalarType is provided, it overrides the type of the Scalar. We add a new `REGISTER_UNARY_ALPHA_OP` macro, and modify the existing `REGISTER_BINARY_ALPHA_OP` to support the new feature. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155304 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
8347268edc
commit
0f47e76937
@@ -104,6 +104,61 @@ kernel void unary_strided(
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T, typename T2, typename F>
|
||||
kernel void unary_alpha_dense(
|
||||
device result_of<F, T, T2>* output [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
constant T2& alpha [[buffer(2)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
output[index] = f(input[index], alpha);
|
||||
}
|
||||
|
||||
template <typename T, typename T2, typename F>
|
||||
kernel void unary_alpha_strided(
|
||||
device result_of<F, T, T2>* output [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
constant long* sizes [[buffer(2)]],
|
||||
constant long* input_strides [[buffer(3)]],
|
||||
constant long* output_strides [[buffer(4)]],
|
||||
constant uint& ndim [[buffer(5)]],
|
||||
constant T2& alpha [[buffer(6)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
int pos[max_ndim];
|
||||
pos_from_thread_index(int(index), pos, sizes, ndim);
|
||||
const auto input_offs = offset_from_coord(pos, input_strides, ndim);
|
||||
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
|
||||
output[output_offs] = f(input[input_offs], alpha);
|
||||
}
|
||||
|
||||
#define REGISTER_UNARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \
|
||||
static_assert( \
|
||||
::metal::is_same_v< \
|
||||
DTYPEO, \
|
||||
::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA>>, \
|
||||
"Output dtype mismatch for unary op " #NAME " and input " #DTYPEI); \
|
||||
template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI \
|
||||
"_" #DTYPEA)]] kernel void ::c10::metal:: \
|
||||
unary_alpha_dense<DTYPEI, DTYPEA, NAME##_functor>( \
|
||||
device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA> * \
|
||||
output, \
|
||||
constant DTYPEI * input, \
|
||||
constant DTYPEA & alpha, \
|
||||
uint index); \
|
||||
template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI \
|
||||
"_" #DTYPEA)]] kernel void ::c10::metal:: \
|
||||
unary_alpha_strided<DTYPEI, DTYPEA, NAME##_functor>( \
|
||||
device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA> * \
|
||||
output, \
|
||||
constant DTYPEI * input, \
|
||||
constant long* sizes, \
|
||||
constant long* input_strides, \
|
||||
constant long* output_strides, \
|
||||
constant uint& ndim, \
|
||||
constant DTYPEA& alpha, \
|
||||
uint index)
|
||||
|
||||
template <typename T>
|
||||
inline T val_at_offs(constant void* ptr, long offs) {
|
||||
return *reinterpret_cast<constant T*>(
|
||||
@@ -191,12 +246,12 @@ kernel void binary_strided(
|
||||
static_cast<res_t>(f(om_t(a), om_t(b)));
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
kernel void alpha_binary_strided(
|
||||
template <typename T, typename T2, typename F>
|
||||
kernel void binary_alpha_strided(
|
||||
device void* output [[buffer(0)]],
|
||||
constant void* input [[buffer(1)]],
|
||||
constant void* other [[buffer(2)]],
|
||||
constant T& alpha [[buffer(3)]],
|
||||
constant T2& alpha [[buffer(3)]],
|
||||
constant long* sizes [[buffer(4)]],
|
||||
constant long* output_strides [[buffer(5)]],
|
||||
constant long* input_strides [[buffer(6)]],
|
||||
@@ -211,7 +266,7 @@ kernel void alpha_binary_strided(
|
||||
const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
|
||||
const auto a = val_at_offs<T>(input, input_offs);
|
||||
const auto b = val_at_offs<T>(other, other_offs);
|
||||
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, alpha);
|
||||
ref_at_offs<result_of<F, T, T, T2>>(output, output_offs) = f(a, b, alpha);
|
||||
}
|
||||
|
||||
template <typename T, typename F, typename om_t = opmath_t<T>>
|
||||
@@ -239,12 +294,12 @@ kernel void binary_strided_cast(
|
||||
ref_at_offs<res_t>(output, output_offs) = static_cast<res_t>(f(a, b));
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
kernel void alpha_binary_strided_cast(
|
||||
template <typename T, typename T2, typename F>
|
||||
kernel void binary_alpha_strided_cast(
|
||||
device void* output [[buffer(0)]],
|
||||
constant void* input [[buffer(1)]],
|
||||
constant void* other [[buffer(2)]],
|
||||
constant T& alpha [[buffer(3)]],
|
||||
constant T2& alpha [[buffer(3)]],
|
||||
constant long* sizes [[buffer(4)]],
|
||||
constant long* output_strides [[buffer(5)]],
|
||||
constant long* input_strides [[buffer(6)]],
|
||||
@@ -261,7 +316,7 @@ kernel void alpha_binary_strided_cast(
|
||||
val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
|
||||
const auto b =
|
||||
val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
|
||||
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, alpha);
|
||||
ref_at_offs<result_of<F, T, T, T2>>(output, output_offs) = f(a, b, alpha);
|
||||
}
|
||||
|
||||
template <typename T, typename F, typename om_t = opmath_t<T>>
|
||||
@@ -275,12 +330,12 @@ kernel void binary_dense(
|
||||
out[tid] = static_cast<res_t>(f(om_t(input[tid]), om_t(other[tid])));
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
kernel void alpha_binary_dense(
|
||||
device result_of<F, T, T, T>* out [[buffer(0)]],
|
||||
template <typename T, typename T2, typename F>
|
||||
kernel void binary_alpha_dense(
|
||||
device result_of<F, T, T, T2>* out [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
constant T* other [[buffer(2)]],
|
||||
constant T& alpha [[buffer(3)]],
|
||||
constant T2& alpha [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
out[tid] = f(input[tid], other[tid], alpha);
|
||||
@@ -302,12 +357,12 @@ kernel void binary_dense_cast(
|
||||
out[tid] = static_cast<res_t>(f(a, b));
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
kernel void alpha_binary_dense_cast(
|
||||
device result_of<F, T, T, T>* out [[buffer(0)]],
|
||||
template <typename T, typename T2, typename F>
|
||||
kernel void binary_alpha_dense_cast(
|
||||
device result_of<F, T, T, T2>* out [[buffer(0)]],
|
||||
constant void* input [[buffer(1)]],
|
||||
constant void* other [[buffer(2)]],
|
||||
constant T& alpha [[buffer(3)]],
|
||||
constant T2& alpha [[buffer(3)]],
|
||||
constant uint4& sizes_types [[buffer(4)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
F f;
|
||||
@@ -369,54 +424,58 @@ kernel void alpha_binary_dense_cast(
|
||||
#define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO) \
|
||||
REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI)
|
||||
|
||||
#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEO) \
|
||||
#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO) \
|
||||
static_assert( \
|
||||
::metal::is_same_v< \
|
||||
DTYPEO, \
|
||||
::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI>>, \
|
||||
::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA>>, \
|
||||
"Output dtype mismatch for binary op " #NAME " and input " #DTYPEI); \
|
||||
template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
|
||||
c10::metal::alpha_binary_strided<DTYPEI, NAME##_functor>( \
|
||||
template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI \
|
||||
"_" #DTYPEA)]] kernel void ::c10::metal:: \
|
||||
binary_alpha_strided<DTYPEI, DTYPEA, NAME##_functor>( \
|
||||
device void* out, \
|
||||
constant void* input, \
|
||||
constant void* other, \
|
||||
constant DTYPEI& alpha, \
|
||||
constant DTYPEA& alpha, \
|
||||
constant long* sizes, \
|
||||
constant long* output_strides, \
|
||||
constant long* input_strides, \
|
||||
constant long* other_strides, \
|
||||
constant uint3& ndim, \
|
||||
uint tid); \
|
||||
template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \
|
||||
metal::alpha_binary_strided_cast<DTYPEI, NAME##_functor>( \
|
||||
template [[host_name(#NAME "_strided_cast_" #DTYPEI \
|
||||
"_" #DTYPEA)]] kernel void ::c10::metal:: \
|
||||
binary_alpha_strided_cast<DTYPEI, DTYPEA, NAME##_functor>( \
|
||||
device void* out, \
|
||||
constant void* input, \
|
||||
constant void* other, \
|
||||
constant DTYPEI& alpha, \
|
||||
constant DTYPEA& alpha, \
|
||||
constant long* sizes, \
|
||||
constant long* output_strides, \
|
||||
constant long* input_strides, \
|
||||
constant long* other_strides, \
|
||||
constant uint4& ndim_types, \
|
||||
uint tid); \
|
||||
template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
|
||||
c10::metal::alpha_binary_dense<DTYPEI, NAME##_functor>( \
|
||||
template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI \
|
||||
"_" #DTYPEA)]] kernel void ::c10::metal:: \
|
||||
binary_alpha_dense<DTYPEI, DTYPEA, NAME##_functor>( \
|
||||
device ::c10::metal:: \
|
||||
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
|
||||
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
|
||||
out_, \
|
||||
constant DTYPEI * input_, \
|
||||
constant DTYPEI * other_, \
|
||||
constant DTYPEI & alpha, \
|
||||
constant DTYPEA & alpha, \
|
||||
uint tid); \
|
||||
template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \
|
||||
metal::alpha_binary_dense_cast<DTYPEI, NAME##_functor>( \
|
||||
device ::c10::metal:: \
|
||||
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
|
||||
out_, \
|
||||
constant void* input, \
|
||||
constant void* other, \
|
||||
constant DTYPEI& alpha, \
|
||||
constant uint4& sizes_types, \
|
||||
uint tid)
|
||||
template \
|
||||
[[host_name(#NAME "_dense_cast_" #DTYPEI "_" #DTYPEA)]] kernel void :: \
|
||||
c10::metal::binary_alpha_dense_cast<DTYPEI, DTYPEA, NAME##_functor>( \
|
||||
device ::c10::metal:: \
|
||||
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> * \
|
||||
out_, \
|
||||
constant void* input, \
|
||||
constant void* other, \
|
||||
constant DTYPEA& alpha, \
|
||||
constant uint4& sizes_types, \
|
||||
uint tid)
|
||||
} // namespace metal
|
||||
} // namespace c10
|
||||
|
||||
Reference in New Issue
Block a user