diff --git a/aten/src/ATen/native/mps/kernels/Indexing.metal b/aten/src/ATen/native/mps/kernels/Indexing.metal
index 56f1ca63308..e61f33dc544 100644
--- a/aten/src/ATen/native/mps/kernels/Indexing.metal
+++ b/aten/src/ATen/native/mps/kernels/Indexing.metal
@@ -5,6 +5,29 @@
 using namespace metal;
 using namespace c10::metal;
 
+namespace c10 {
+namespace metal {
+// There is no atomic 64-bit add in Metal yet, but this implements a consistent
+// add, i.e. if multiple threads modify the same 64-bit value, the result stored
+// at the address will eventually be equal to its original value plus the sum of
+// all operands
+template <>
+struct AtomicType<long> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, long value) {
+    const auto value_bits = as_type<ulong>(value);
+    const uint low = static_cast<uint>(value_bits);
+    uint high = static_cast<uint>(value_bits >> 32);
+    auto ptr = data + (offset << 1);
+    auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed);
+    high += (old_low + low < old_low) ? 1 : 0;
+    atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed);
+  }
+};
+
+} // namespace metal
+} // namespace c10
+
 struct IndexAB {
   constant int64_t* indexArray;
 };
@@ -211,7 +234,11 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial);
 
 REGISTER_INDEX_OP(put_accumulate, float, float);
 REGISTER_INDEX_OP(put_accumulate, half, half);
+REGISTER_INDEX_OP(put_accumulate, long, long);
 REGISTER_INDEX_OP(put_accumulate, int, int);
+REGISTER_INDEX_OP(put_accumulate, short, short);
+REGISTER_INDEX_OP(put_accumulate, char, char);
+REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
 REGISTER_INDEX_OP(put_accumulate, bool, bool);
 #if __METAL_VERSION__ >= 310
 REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index f6187fbaec3..f00d155559d 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -121,8 +121,8 @@ static void validateInputData(const TensorIteratorBase& iter,
   const auto scalar_type = inputTensor.scalar_type();
 
   if (accumulate) {
-    // No atomic support for the rest of dtypes
-    TORCH_CHECK(supportedFloatingType(scalar_type) || scalar_type == kInt || scalar_type == kBool);
+    // No atomic support for the complex dtypes
+    TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type));
   } else {
     TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
                     scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,
diff --git a/c10/metal/atomic.h b/c10/metal/atomic.h
index 84698024e88..b2bce2b8b5f 100644
--- a/c10/metal/atomic.h
+++ b/c10/metal/atomic.h
@@ -35,15 +35,16 @@ static inline void atomic_add_helper(
     device ::metal::atomic<uint>* data,
     long offset,
     T value) {
-  auto ptr = data + (offset >> 1);
+  constexpr auto elem_per_enum = sizeof(uint) / sizeof(T);
+  auto ptr = data + (offset / elem_per_enum);
   auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
   union {
     uint i;
-    T t[2];
+    T t[elem_per_enum];
   } val;
   do {
     val.i = old;
-    val.t[offset & 1] += value;
+    val.t[offset & (elem_per_enum - 1)] += value;
   } while (!::metal::atomic_compare_exchange_weak_explicit(
       ptr,
       &old,
@@ -56,7 +57,31 @@ template <>
 struct AtomicType<half> {
   using type = ::metal::atomic<uint>;
   static inline void atomic_add(device type* data, long offset, half value) {
-    atomic_add_helper(data, offset, value);
-  }
-};
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<short> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, short value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<char> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, char value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
+
+template <>
+struct AtomicType<uchar> {
+  using type = ::metal::atomic<uint>;
+  static inline void atomic_add(device type* data, long offset, uchar value) {
+    atomic_add_helper(data, offset, value);
+  }
+};
 
diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py
index ce10aaa9489..8d1d72c691e 100644
--- a/torch/testing/_internal/common_mps.py
+++ b/torch/testing/_internal/common_mps.py
@@ -541,13 +541,6 @@ if torch.backends.mps.is_available():
         # round not working properly for float16 and bfloat16
         "round": [torch.float16, torch.bfloat16],
         "rounddecimals_0": [torch.bfloat16],
-        # atomic operations not supported
-        "_unsafe_masked_index_put_accumulate": [
-            torch.int8,
-            torch.uint8,
-            torch.int16,
-            torch.int64,
-        ],
     }
 
     if MACOS_VERSION < 14.0:
@@ -642,12 +635,6 @@ if torch.backends.mps.is_available():
             torch.float16,
             torch.bfloat16,
         ],
-        "index_put": [
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int64,
-        ],
         # zero to negative integer powers are undefined
        "__rpow__": [torch.int8, torch.int16, torch.int32, torch.int64],
        "resize_": [torch.float16, torch.float32, torch.bfloat16],
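
Note (editor's addition, not part of the patch): the sketch below is a host-side C++ analogue of the carry-propagation trick used by the new AtomicType<long>::atomic_add in Indexing.metal. It is a minimal illustration, assuming the 64-bit counter is stored as two adjacent little-endian 32-bit atomic words; the function and variable names are illustrative and not PyTorch API.

#include <atomic>
#include <cstdint>
#include <cstdio>

// Add a signed 64-bit value to a counter stored as two adjacent 32-bit atomic
// words (low word first), using only 32-bit atomic adds: add the low halves
// first, then carry into the high word if the low-word addition wrapped around.
void atomic_add_i64(std::atomic<uint32_t>* words, int64_t value) {
  const uint64_t bits = static_cast<uint64_t>(value);
  const uint32_t low = static_cast<uint32_t>(bits);
  uint32_t high = static_cast<uint32_t>(bits >> 32);
  // fetch_add returns the previous low word; wraparound means a carry is due.
  const uint32_t old_low = words[0].fetch_add(low, std::memory_order_relaxed);
  high += (static_cast<uint32_t>(old_low + low) < old_low) ? 1 : 0;
  words[1].fetch_add(high, std::memory_order_relaxed);
}

int main() {
  // 0x00000000FFFFFFF0 stored as {low, high}; adding 0x20 must carry into high.
  std::atomic<uint32_t> words[2] = {0xFFFFFFF0u, 0u};
  atomic_add_i64(words, 0x20);
  const uint64_t result =
      (static_cast<uint64_t>(words[1].load()) << 32) | words[0].load();
  std::printf("0x%llx\n", static_cast<unsigned long long>(result));  // 0x100000010
  return 0;
}

As the patch's comment states, concurrent adders make the stored value only eventually consistent: a reader may momentarily observe a state where the low word has been updated but the carry has not yet landed in the high word.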