diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h
index fe82f554e0f..6d35a5e9b2a 100644
--- a/aten/src/ATen/native/mps/MetalShaderLibrary.h
+++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h
@@ -137,11 +137,13 @@ class MetalShaderLibrary {
   void exec_unary_kernel(
       TensorIteratorBase& iter,
       const std::string& name,
-      std::optional<int64_t> extra = std::nullopt);
+      const std::optional<c10::Scalar> alpha = std::nullopt,
+      const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
   void exec_binary_kernel(
       TensorIteratorBase& iter,
       const std::string& name,
-      const std::optional<c10::Scalar> alpha = std::nullopt);
+      const std::optional<c10::Scalar> alpha = std::nullopt,
+      const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
 
  protected:
   virtual MTLLibrary_t getLibrary();
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index c7d391a625b..583eb410345 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -990,11 +990,12 @@ void MetalShaderLibrary::bind_tensors(id<MTLComputeCommandEncoder> encoder, Tens
 void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
                                            const std::string& name,
-                                           std::optional<int64_t> extra) {
+                                           std::optional<Scalar> alpha,
+                                           std::optional<ScalarType> scalar_arg_type) {
   // Decompose 64-bit tensor into 32-bit ones
   if (!iter.can_use_32bit_indexing()) {
     for (auto&& sub_iter : iter.with_32bit_indexing()) {
-      exec_unary_kernel(sub_iter, name, extra);
+      exec_unary_kernel(sub_iter, name, alpha, scalar_arg_type);
     }
     return;
   }
@@ -1006,11 +1007,13 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
     return;
   }
   using namespace mps;
-  auto kernel_name = fmt::format("{}_{}_{}_{}",
+  const auto alpha_type = scalar_arg_type.has_value() ? scalar_arg_type.value() : iter.common_dtype();
+  auto kernel_name = fmt::format("{}_{}_{}_{}{}",
                                  name,
                                  iter.is_contiguous() ? "dense" : "strided",
                                  scalarToMetalTypeString(outputTensor),
-                                 scalarToMetalTypeString(inputTensor));
+                                 scalarToMetalTypeString(inputTensor),
+                                 alpha.has_value() ? fmt::format("_{}", scalarToMetalTypeString(alpha_type)) : "");
   @autoreleasepool {
     auto cplState = getPipelineStateForFunc(kernel_name);
@@ -1029,8 +1032,8 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
                    outputTensor.strides(),
                    inputTensor.ndimension());
       }
-      if (extra) {
-        mtl_setBytes(computeEncoder, *extra, iter.is_contiguous() ? 2 : 6);
+      if (alpha) {
+        mtl_setBytes(computeEncoder, getMPSScalar(*alpha, alpha_type), iter.is_contiguous() ? 2 : 6);
       }
       mtl_dispatch1DJob(computeEncoder, cplState, length);
@@ -1041,7 +1044,8 @@ void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
 
 void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
                                             const std::string& name,
-                                            std::optional<Scalar> alpha) {
+                                            std::optional<Scalar> alpha,
+                                            std::optional<ScalarType> scalar_arg_type) {
   // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?)
   // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with
   // double as common dtype (because Python floating point are always 64-bit values)
@@ -1055,7 +1059,7 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
   // Decompose 64-bit tensor into 32-bit ones
   if (!iter.can_use_32bit_indexing()) {
     for (auto&& sub_iter : iter.with_32bit_indexing()) {
-      exec_binary_kernel(sub_iter, name, alpha);
+      exec_binary_kernel(sub_iter, name, alpha, scalar_arg_type);
     }
     return;
   }
@@ -1081,10 +1085,13 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
   MPSStream* mpsStream = getCurrentMPSStream();
   const auto cast_needed = input.scalar_type() != other.scalar_type();
   const auto suffix = iter.is_contiguous() ? "dense" : "strided";
+  const auto alpha_type = scalar_arg_type.has_value() ? scalar_arg_type.value() : iter.common_dtype();
+  const auto alpha_suffix = alpha.has_value() ? fmt::format("_{}", scalarToMetalTypeString(alpha_type)) : "";
   // TODO: Implicitly pass both input and output types to non-cast kernels
   const auto kernel_name = cast_needed
-      ? fmt::format("{}_{}_cast_{}", name, suffix, scalarToMetalTypeString(out))
-      : fmt::format("{}_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input));
+      ? fmt::format("{}_{}_cast_{}{}", name, suffix, scalarToMetalTypeString(out), alpha_suffix)
+      : fmt::format(
+            "{}_{}_{}_{}{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input), alpha_suffix);
   dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
     @autoreleasepool {
       auto computeEncoder = mpsStream->commandEncoder();
@@ -1098,7 +1105,7 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
       // i.e. it's true for both row-first and column-first tensors
       if (iter.is_contiguous()) {
         if (alpha) {
-          mtl_setBytes(computeEncoder, getMPSScalar(*alpha, iter.common_dtype()), 3);
+          mtl_setBytes(computeEncoder, getMPSScalar(*alpha, alpha_type), 3);
         }
         if (cast_needed) {
           std::array<uint32_t, 4> size_and_types = {static_cast<uint32_t>(c10::elementSize(input.scalar_type())),
@@ -1117,7 +1124,7 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter,
                                                     static_cast<uint32_t>(out.scalar_type())};
         if (alpha) {
           mtl_setArgs<3>(computeEncoder,
-                         getMPSScalar(*alpha, iter.common_dtype()),
+                         getMPSScalar(*alpha, alpha_type),
                          iter.shape(),
                          iter.strides(0),
                          iter.strides(1),
diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal
new file mode 100644
index 00000000000..08acd9c507a
--- /dev/null
+++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal
@@ -0,0 +1,31 @@
+#include <c10/metal/indexing.h>
+#include <c10/metal/utils.h>
+#include <metal_stdlib>
+using namespace metal;
+using namespace c10::metal;
+
+struct hardshrink_functor {
+  template <typename T>
+  inline T operator()(const T x, const T lambda) {
+    return (x >= -lambda && x <= lambda) ? T(0) : x;
+  }
+};
+
+struct hardshrink_backward_functor {
+  template <typename T>
+  inline T operator()(const T grad_output, const T x, const T lambda) {
+    return (x >= -lambda && x <= lambda) ? T(0) : grad_output;
+  }
+};
+
+REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
+REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
+#if __METAL_VERSION__ >= 310
+REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
+#endif
+
+REGISTER_BINARY_ALPHA_OP(hardshrink_backward, float, float, float);
+REGISTER_BINARY_ALPHA_OP(hardshrink_backward, half, half, half);
+#if __METAL_VERSION__ >= 310
+REGISTER_BINARY_ALPHA_OP(hardshrink_backward, bfloat, bfloat, bfloat);
+#endif
diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
index 58809f32194..a1d748a5fa1 100644
--- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
+++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
@@ -362,35 +362,35 @@ REGISTER_OPMATH_FLOAT_BINARY_OP(remainder);
 REGISTER_INTEGER_BINARY_OP(remainder);
 REGISTER_OPMATH_FLOAT_BINARY_OP(fmod);
 REGISTER_INTEGER_BINARY_OP(fmod);
-REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
-REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
-REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
-REGISTER_BINARY_ALPHA_OP(add_alpha, half, half);
-REGISTER_BINARY_ALPHA_OP(add_alpha, short, short);
-REGISTER_BINARY_ALPHA_OP(add_alpha, uchar, uchar);
-REGISTER_BINARY_ALPHA_OP(add_alpha, char, char);
-REGISTER_BINARY_ALPHA_OP(add_alpha, bool, bool);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, long, long);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, int, int);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, float, float);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, half, half);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, short, short);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, uchar, uchar);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, char, char);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, bool, bool);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, long, long);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, int, int);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, float, float);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, half, half);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, short, short);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool);
+REGISTER_BINARY_ALPHA_OP(add_alpha, long, long, long);
+REGISTER_BINARY_ALPHA_OP(add_alpha, int, int, int);
+REGISTER_BINARY_ALPHA_OP(add_alpha, float, float, float);
+REGISTER_BINARY_ALPHA_OP(add_alpha, half, half, half);
+REGISTER_BINARY_ALPHA_OP(add_alpha, short, short, short);
+REGISTER_BINARY_ALPHA_OP(add_alpha, uchar, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(add_alpha, char, char, char);
+REGISTER_BINARY_ALPHA_OP(add_alpha, bool, bool, bool);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, long, long, long);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, int, int, int);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, float, float, float);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, half, half, half);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, short, short, short);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, uchar, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, char, char, char);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, bool, bool, bool);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, long, long, long);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, int, int, int);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, float, float, float);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, half, half, half);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, short, short, short);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
 #if __METAL_VERSION__ >= 310
-REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
 #endif
 
 // Complex binary functions
@@ -406,9 +406,9 @@ REGISTER_BINARY_OP(add, float2, float2);
 REGISTER_BINARY_OP(add, half2, half2);
 REGISTER_BINARY_OP(sub, float2, float2);
 REGISTER_BINARY_OP(sub, half2, half2);
-REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2);
-REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2);
-REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2);
-REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2);
+REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2);
+REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2, half2);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2, float2);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2, half2);
diff --git a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal
index 94b4d646fe5..638f0711d82 100644
--- a/aten/src/ATen/native/mps/kernels/UnaryKernel.metal
+++ b/aten/src/ATen/native/mps/kernels/UnaryKernel.metal
@@ -423,25 +423,25 @@ kernel void round_decimals_strided(
       rint(exp10(float(ndigits)) * input[input_offs]) * exp10(float(-ndigits)));
 }
 
-#define INSTANTIATE_ROUND_DECIMALS(DTYPE)                                  \
-  template                                                                 \
-      [[host_name("round_decimals_dense_" #DTYPE "_" #DTYPE)]] kernel void \
-      round_decimals_dense<DTYPE>(                                         \
-          device DTYPE* output [[buffer(0)]],                              \
-          constant DTYPE* input [[buffer(1)]],                             \
-          constant long& ndigits [[buffer(2)]],                            \
-          uint index [[thread_position_in_grid]]);                         \
-  template                                                                 \
-      [[host_name("round_decimals_strided_" #DTYPE "_" #DTYPE)]] kernel void \
-      round_decimals_strided<DTYPE>(                                       \
-          device DTYPE* output [[buffer(0)]],                              \
-          constant DTYPE* input [[buffer(1)]],                             \
-          constant long* sizes,                                            \
-          constant long* input_strides,                                    \
-          constant long* output_strides,                                   \
-          constant uint& ndim,                                             \
-          constant long& ndigits [[buffer(6)]],                            \
-          uint index)
+#define INSTANTIATE_ROUND_DECIMALS(DTYPE)                           \
+  template [[host_name("round_decimals_dense_" #DTYPE "_" #DTYPE   \
+                       "_long")]] kernel void                      \
+  round_decimals_dense<DTYPE>(                                      \
+      device DTYPE* output [[buffer(0)]],                           \
+      constant DTYPE* input [[buffer(1)]],                          \
+      constant long& ndigits [[buffer(2)]],                         \
+      uint index [[thread_position_in_grid]]);                      \
+  template [[host_name("round_decimals_strided_" #DTYPE "_" #DTYPE \
+                       "_long")]] kernel void                      \
+  round_decimals_strided<DTYPE>(                                    \
+      device DTYPE* output [[buffer(0)]],                           \
+      constant DTYPE* input [[buffer(1)]],                          \
+      constant long* sizes,                                         \
+      constant long* input_strides,                                 \
+      constant long* output_strides,                                \
+      constant uint& ndim,                                          \
+      constant long& ndigits [[buffer(6)]],                         \
+      uint index)
 
 INSTANTIATE_ROUND_DECIMALS(float);
 INSTANTIATE_ROUND_DECIMALS(half);
diff --git a/aten/src/ATen/native/mps/operations/ActivationKernel.mm b/aten/src/ATen/native/mps/operations/ActivationKernel.mm
new file mode 100644
index 00000000000..e50003f56c4
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/ActivationKernel.mm
@@ -0,0 +1,27 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/TensorIterator.h>
+#include <ATen/native/Activation.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include <c10/core/Scalar.h>
+
+namespace at::native {
+
+#ifndef PYTORCH_JIT_COMPILE_SHADERS
+static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
+#else
+#include <ATen/native/mps/ActivationKernel_metallib.h>
+#endif
+
+static void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) {
+  lib.exec_unary_kernel(iter, "hardshrink", lambda);
+}
+
+static void hardshrink_backward_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) {
+  lib.exec_binary_kernel(iter, "hardshrink_backward", lambda);
+}
+
+REGISTER_DISPATCH(hardshrink_stub, hardshrink_kernel);
+REGISTER_DISPATCH(shrink_backward_stub, hardshrink_backward_kernel);
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/UnaryKernel.mm b/aten/src/ATen/native/mps/operations/UnaryKernel.mm
index b8fbcea4534..45f311efba7 100644
--- a/aten/src/ATen/native/mps/operations/UnaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryKernel.mm
@@ -21,7 +21,7 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
   REGISTER_DISPATCH(NAME##_stub, NAME##_kernel_mps)
 
 static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
-  lib.exec_unary_kernel(iter, "round_decimals", decimals);
+  lib.exec_unary_kernel(iter, "round_decimals", Scalar(decimals), ScalarType::Long);
 }
 
 REGISTER_UNARY_TI_DISPATCH(exp);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9e53e00ddd6..f3422308fed 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5171,7 +5171,7 @@
   structured_inherits: TensorIteratorBase
   device_check: NoCheck   # TensorIterator
   dispatch:
-    CPU, CUDA: hardshrink_out
+    CPU, CUDA, MPS: hardshrink_out
 
 - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
   structured_delegate: hardshrink.out
@@ -5183,7 +5183,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: hardshrink_backward_out
+    CPU, CUDA, MPS: hardshrink_backward_out
 
 - func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
   structured_delegate: hardshrink_backward.grad_input
diff --git a/c10/metal/indexing.h b/c10/metal/indexing.h
index 9c6b5cf200e..b8d7d30077a 100644
--- a/c10/metal/indexing.h
+++ b/c10/metal/indexing.h
@@ -104,6 +104,61 @@ kernel void unary_strided(
   }                                                                       \
 }
 
+template <typename T, typename T2, typename F>
+kernel void unary_alpha_dense(
+    device result_of<F, T, T2>* output [[buffer(0)]],
+    constant T* input [[buffer(1)]],
+    constant T2& alpha [[buffer(2)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  output[index] = f(input[index], alpha);
+}
+
+template <typename T, typename T2, typename F>
+kernel void unary_alpha_strided(
+    device result_of<F, T, T2>* output [[buffer(0)]],
+    constant T* input [[buffer(1)]],
+    constant long* sizes [[buffer(2)]],
+    constant long* input_strides [[buffer(3)]],
+    constant long* output_strides [[buffer(4)]],
+    constant uint& ndim [[buffer(5)]],
+    constant T2& alpha [[buffer(6)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim);
+  output[output_offs] = f(input[input_offs], alpha);
+}
+
+#define REGISTER_UNARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO)               \
+  static_assert(                                                            \
+      ::metal::is_same_v<                                                   \
+          DTYPEO,                                                           \
+          ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA>>,         \
+      "Output dtype mismatch for unary op " #NAME " and input " #DTYPEI);   \
+  template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI                  \
+                       "_" #DTYPEA)]] kernel void ::c10::metal::            \
+      unary_alpha_dense<DTYPEI, DTYPEA, NAME##_functor>(                    \
+          device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA> *  \
+              output,                                                       \
+          constant DTYPEI * input,                                          \
+          constant DTYPEA & alpha,                                          \
+          uint index);                                                      \
+  template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI                \
+                       "_" #DTYPEA)]] kernel void ::c10::metal::            \
+      unary_alpha_strided<DTYPEI, DTYPEA, NAME##_functor>(                  \
+          device ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEA> *  \
+              output,                                                       \
+          constant DTYPEI * input,                                          \
+          constant long* sizes,                                             \
+          constant long* input_strides,                                     \
+          constant long* output_strides,                                    \
+          constant uint& ndim,                                              \
+          constant DTYPEA& alpha,                                           \
+          uint index)
+
 template <typename T>
 inline T val_at_offs(constant void* ptr, long offs) {
   return *reinterpret_cast<constant T*>(
@@ -191,12 +246,12 @@ kernel void binary_strided(
       static_cast<T>(f(om_t(a), om_t(b)));
 }
 
-template <typename T, typename F>
-kernel void alpha_binary_strided(
+template <typename T, typename T2, typename F>
+kernel void binary_alpha_strided(
     device void* output [[buffer(0)]],
     constant void* input [[buffer(1)]],
     constant void* other [[buffer(2)]],
-    constant T& alpha [[buffer(3)]],
+    constant T2& alpha [[buffer(3)]],
     constant long* sizes [[buffer(4)]],
     constant long* output_strides [[buffer(5)]],
     constant long* input_strides [[buffer(6)]],
@@ -211,7 +266,7 @@ kernel void alpha_binary_strided(
   const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
   const auto a = val_at_offs<T>(input, input_offs);
   const auto b = val_at_offs<T>(other, other_offs);
-  ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, alpha);
+  ref_at_offs<result_of<F, T, T, T2>>(output, output_offs) = f(a, b, alpha);
 }
 
 template <typename T, typename F, typename om_t = opmath_t<T>>
@@ -239,12 +294,12 @@ kernel void binary_strided_cast(
   ref_at_offs<T>(output, output_offs) = static_cast<T>(f(a, b));
 }
 
-template <typename T, typename F>
-kernel void alpha_binary_strided_cast(
+template <typename T, typename T2, typename F>
+kernel void binary_alpha_strided_cast(
    device void* output [[buffer(0)]],
    constant void* input [[buffer(1)]],
    constant void* other [[buffer(2)]],
-   constant T& alpha [[buffer(3)]],
+   constant T2& alpha [[buffer(3)]],
    constant long* sizes [[buffer(4)]],
    constant long* output_strides [[buffer(5)]],
    constant long* input_strides [[buffer(6)]],
@@ -261,7 +316,7 @@ kernel void alpha_binary_strided_cast(
       val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
   const auto b =
       val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
-  ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, alpha);
+  ref_at_offs<result_of<F, T, T, T2>>(output, output_offs) = f(a, b, alpha);
 }
 
 template <typename T, typename F, typename om_t = opmath_t<T>>
@@ -275,12 +330,12 @@ kernel void binary_dense(
   out[tid] = static_cast<T>(f(om_t(input[tid]), om_t(other[tid])));
 }
 
-template <typename T, typename F>
-kernel void alpha_binary_dense(
-    device result_of<F, T, T, T>* out [[buffer(0)]],
+template <typename T, typename T2, typename F>
+kernel void binary_alpha_dense(
+    device result_of<F, T, T, T2>* out [[buffer(0)]],
     constant T* input [[buffer(1)]],
     constant T* other [[buffer(2)]],
-    constant T& alpha [[buffer(3)]],
+    constant T2& alpha [[buffer(3)]],
     uint tid [[thread_position_in_grid]]) {
   F f;
   out[tid] = f(input[tid], other[tid], alpha);
@@ -302,12 +357,12 @@ kernel void binary_dense_cast(
   out[tid] = static_cast<T>(f(a, b));
 }
 
-template <typename T, typename F>
-kernel void alpha_binary_dense_cast(
-    device result_of<F, T, T, T>* out [[buffer(0)]],
+template <typename T, typename T2, typename F>
+kernel void binary_alpha_dense_cast(
+    device result_of<F, T, T, T2>* out [[buffer(0)]],
     constant void* input [[buffer(1)]],
    constant void* other [[buffer(2)]],
-    constant T& alpha [[buffer(3)]],
+    constant T2& alpha [[buffer(3)]],
    constant uint4& sizes_types [[buffer(4)]],
    uint tid [[thread_position_in_grid]]) {
   F f;
@@ -369,54 +424,58 @@ kernel void alpha_binary_dense_cast(
 #define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO) \
   REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, DTYPEI)
 
-#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEO)                        \
+#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEA, DTYPEO)                \
   static_assert(                                                              \
       ::metal::is_same_v<                                                     \
           DTYPEO,                                                             \
-          ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI>>,   \
+          ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA>>,   \
      "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI);     \
-  template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void ::\
-      c10::metal::alpha_binary_strided<DTYPEI, NAME##_functor>(               \
+  template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI                  \
+                       "_" #DTYPEA)]] kernel void ::c10::metal::              \
+      binary_alpha_strided<DTYPEI, DTYPEA, NAME##_functor>(                   \
          device void* out,                                                    \
          constant void* input,                                                \
          constant void* other,                                                \
-         constant DTYPEI& alpha,                                              \
+         constant DTYPEA& alpha,                                              \
          constant long* sizes,                                                \
          constant long* output_strides,                                       \
          constant long* input_strides,                                        \
         constant long* other_strides,                                         \
         constant uint3& ndim,                                                 \
         uint tid);                                                            \
-  template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10::  \
-      metal::alpha_binary_strided_cast<DTYPEI, NAME##_functor>(               \
+  template [[host_name(#NAME "_strided_cast_" #DTYPEI                         \
+                       "_" #DTYPEA)]] kernel void ::c10::metal::              \
+      binary_alpha_strided_cast<DTYPEI, DTYPEA, NAME##_functor>(              \
         device void* out,                                                     \
        constant void* input,                                                  \
        constant void* other,                                                  \
-       constant DTYPEI& alpha,                                                \
+       constant DTYPEA& alpha,                                                \
        constant long* sizes,                                                  \
        constant long* output_strides,                                         \
        constant long* input_strides,                                          \
        constant long* other_strides,                                          \
        constant uint4& ndim_types,                                            \
        uint tid);                                                             \
-  template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void ::  \
-      c10::metal::alpha_binary_dense<DTYPEI, NAME##_functor>(                 \
+  template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI                    \
+                       "_" #DTYPEA)]] kernel void ::c10::metal::              \
+      binary_alpha_dense<DTYPEI, DTYPEA, NAME##_functor>(                     \
          device ::c10::metal::                                                \
-             result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> *              \
+             result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> *              \
              out_,                                                            \
          constant DTYPEI * input_,                                            \
          constant DTYPEI * other_,                                            \
-         constant DTYPEI & alpha,                                             \
+         constant DTYPEA & alpha,                                             \
          uint tid);                                                           \
-  template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10::    \
-      metal::alpha_binary_dense_cast<DTYPEI, NAME##_functor>(                 \
-      device ::c10::metal::                                                   \
-          result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> *                 \
-          out_,                                                               \
-      constant void* input,                                                   \
-      constant void* other,                                                   \
-      constant DTYPEI& alpha,                                                 \
-      constant uint4& sizes_types,                                            \
-      uint tid)
+  template                                                                    \
+      [[host_name(#NAME "_dense_cast_" #DTYPEI "_" #DTYPEA)]] kernel void ::  \
+          c10::metal::binary_alpha_dense_cast<DTYPEI, DTYPEA, NAME##_functor>(\
+              device ::c10::metal::                                           \
+                  result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEA> *         \
+                  out_,                                                       \
+              constant void* input,                                           \
+              constant void* other,                                           \
+              constant DTYPEA& alpha,                                         \
+              constant uint4& sizes_types,                                    \
+              uint tid)
 
 } // namespace metal
 } // namespace c10
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index c071cd11e8f..924bb7962b2 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -4040,13 +4040,6 @@ module_db: list[ModuleInfo] = [
                ),
     ModuleInfo(torch.nn.Hardshrink,
                module_inputs_func=module_inputs_torch_nn_Hardshrink,
-               skips=(
-                   # not supported on MPS backend
-                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward', device_type='mps'),
-                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_if_train_and_eval_modes_differ', device_type='mps'),
-                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format', device_type='mps'),
-                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_non_contiguous_tensors', device_type='mps'),
-                   DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_save_load', device_type='mps'),),
                ),
     ModuleInfo(torch.nn.Hardswish,
                module_inputs_func=module_inputs_torch_nn_Hardswish,
diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py
index 6f2d95643d0..d818fa4e7fb 100644
--- a/torch/testing/_internal/common_mps.py
+++ b/torch/testing/_internal/common_mps.py
@@ -437,7 +437,13 @@ if torch.backends.mps.is_available():
         "nn.functional.avg_pool3d": None,
         "nn.functional.ctc_loss": None,
         "nn.functional.embedding_bag": None,
-        "nn.functional.hardshrink": None,
+        "nn.functional.hardshrink": [
+            torch.uint8,
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ],
         "nn.functional.max_pool3d": None,
         "nn.functional.max_unpool1d": None,
         "nn.functional.max_unpool2d": None,
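
A quick smoke test of the new code paths (an illustrative sketch, not part of the patch; it assumes a machine where `torch.backends.mps.is_available()` returns `True`). With this change the scalar argument's dtype is appended to the generated kernel name, so the float forward kernel is registered as `hardshrink_dense_float_float_float` and the re-registered rounding kernel as `round_decimals_dense_float_float_long`; functionally, CPU and MPS should now agree on hardshrink forward and backward for both dense and strided inputs:

```python
import torch
import torch.nn.functional as F

# Forward: values in [-lambd, lambd] are zeroed (unary alpha kernel).
x_cpu = torch.linspace(-1.0, 1.0, 9, requires_grad=True)
x_mps = x_cpu.detach().to("mps").requires_grad_()
torch.testing.assert_close(
    F.hardshrink(x_cpu, lambd=0.5),
    F.hardshrink(x_mps, lambd=0.5).cpu(),
)

# Backward: gradient flows only outside the shrink band (binary alpha kernel).
F.hardshrink(x_cpu, lambd=0.5).sum().backward()
F.hardshrink(x_mps, lambd=0.5).sum().backward()
torch.testing.assert_close(x_cpu.grad, x_mps.grad.cpu())

# A transposed view exercises the *_strided kernel variants.
m = torch.randn(4, 4)
torch.testing.assert_close(
    F.hardshrink(m.t()),
    F.hardshrink(m.to("mps").t()).cpu(),
)
```

Integer dtypes remain on the common_mps.py blocklist above because the Metal kernels are only instantiated for `float`, `half`, and (with Metal 3.1+) `bfloat`.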