diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
index c7eaa802af1..c5dbf05039e 100644
--- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
@@ -5,6 +5,7 @@
 #include
 #include
+#include
 #include
 #include
 #include

@@ -78,12 +79,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) {
     reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
       [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
+    AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] {
       using Vec = Vectorized<vec_scalar_t<scalar_t>>;
       reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
         [=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); },
         [=](Vec a, Vec b) -> Vec { return minimum(a, b); });
-    });
+    }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
   }
 }
@@ -103,12 +104,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) {
     reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
       [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
+    AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] {
       using Vec = Vectorized<vec_scalar_t<scalar_t>>;
       reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
         [=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); },
         [=](Vec a, Vec b) -> Vec { return maximum(a, b); });
-    });
+    }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
   }
 }
@@ -199,7 +200,7 @@ void aminmax_allreduce_kernel(
       }
     );
   } else {
-    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
+    AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
       using Vec = Vectorized<vec_scalar_t<scalar_t>>;
       using scalar_t_pair = std::pair<scalar_t, scalar_t>;
       reduce_all_impl_vec_two_outputs<scalar_t>(
@@ -214,7 +215,7 @@
         [=](Vec a, Vec b) -> Vec { return minimum(a, b); },
         [=](Vec a, Vec b) -> Vec { return maximum(a, b); }
       );
-    });
+    }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
   }
 }

diff --git a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu
index e01ca6c88eb..0006a24dbc4 100644
--- a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu
@@ -12,6 +12,7 @@
 #include
 #include
+#include
 #include
 #include

@@ -33,24 +34,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) {
 }

 void min_values_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() {
+  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
     min_values_kernel_cuda_impl<scalar_t>(iter);
-  });
+  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
 }

 void min_launch_kernel(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() {
+  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
     gpu_reduce_kernel<scalar_t, scalar_t>(
       iter,
       MinOps<scalar_t>{},
       thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
-  });
+  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
 }

 void min_all_launch_kernel(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] {
+  AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] {
     min_values_kernel_cuda_impl<scalar_t>(iter);
-  });
+  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
 }

 REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda)