diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
index 65115fb34a7..774a8283100 100644
--- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
+++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
@@ -246,6 +246,20 @@ struct div_trunc_functor {
   }
 };
 
+struct remainder_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return T(a - b * c10::metal::floor_divide(a, b));
+  }
+};
+
+struct fmod_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return c10::metal::fmod(a, b);
+  }
+};
+
 // Some helper defines
 #if __METAL_VERSION__ >= 310
 #define _METAL_310_PLUS(x) x
@@ -304,6 +318,10 @@ REGISTER_FLOAT_BINARY_OP(div_trunc);
 REGISTER_INTEGER_BINARY_OP(div_trunc);
 REGISTER_OPMATH_FLOAT_BINARY_OP(div_true);
 REGISTER_INT2FLOAT_BINARY_OP(div_true);
+REGISTER_OPMATH_FLOAT_BINARY_OP(remainder);
+REGISTER_INTEGER_BINARY_OP(remainder);
+REGISTER_OPMATH_FLOAT_BINARY_OP(fmod);
+REGISTER_INTEGER_BINARY_OP(fmod);
 REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
 REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
 REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index e4bc52cf6d2..0cbdf7132c7 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -159,6 +159,14 @@ static void div_trunc_mps_kernel(TensorIteratorBase& iter) {
   lib.exec_binary_kernel(iter, "div_trunc");
 }
 
+static void remainder_mps_kernel(TensorIteratorBase& iter) {
+  lib.exec_binary_kernel(iter, "remainder");
+}
+
+static void fmod_mps_kernel(TensorIteratorBase& iter) {
+  lib.exec_binary_kernel(iter, "fmod");
+}
+
 REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
 REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
 REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
@@ -178,4 +186,6 @@ REGISTER_DISPATCH(mul_stub, &mul_mps_kernel)
 REGISTER_DISPATCH(div_true_stub, &div_true_mps_kernel)
 REGISTER_DISPATCH(div_floor_stub, &div_floor_mps_kernel)
 REGISTER_DISPATCH(div_trunc_stub, &div_trunc_mps_kernel)
+REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
+REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
 } // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 2249223343d..a9589ecc490 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -14,7 +14,6 @@
 #include
 #include
 #include
-#include <ATen/ops/fmod_native.h>
 #include
 #include
 #include
@@ -30,7 +29,6 @@
 #include
 #include
 #include
-#include <ATen/ops/remainder_native.h>
 #include
 #include
 #include
@@ -184,61 +182,6 @@ static void binaryOpScalar(const Tensor& self,
   binaryOpTensor(self, wrapped_scalar_tensor(other), output, op_name, binaryBlock);
 }
 
-static void div_mode_template(const Tensor& self,
-                              const Tensor& other,
-                              std::optional<std::string_view> rounding_mode,
-                              const Tensor& output,
-                              const std::string& op_name) {
-  if (rounding_mode.has_value() && *rounding_mode == "trunc") {
-    TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input");
-  }
-  BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
-    MPSGraph* mpsGraph = cachedGraph->graph();
-    bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
-    if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
-      primaryCastTensor = [mpsGraph castTensor:primaryCastTensor toType:MPSDataTypeFloat32 name:@"primaryCastTensor"];
-      secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
-                                          toType:MPSDataTypeFloat32
-                                            name:@"secondaryCastTensor"];
-    }
-    MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
-                                                    secondaryTensor:secondaryCastTensor
-                                                               name:nil];
-    // Rounding is a no-op for integral types, and also a reasonable workaround
-    // for an MPSGraph bug on Apple Silicon that throws `Function floorOp_i64 was not found in the library`
-    // See https://github.com/pytorch/pytorch/issues/84995
-    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
-    if (!rounding_mode.has_value() || !isFloatOutput) {
-      return divTensor;
-    } else if (*rounding_mode == "trunc") {
-      auto truncTensor = [mpsGraph truncateWithTensor:divTensor name:nil];
-      if (op_name == "fmod_mps_out") {
-        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
-                                                   secondaryTensor:secondaryCastTensor
-                                                              name:nil];
-        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
-      }
-      return truncTensor;
-    } else if (*rounding_mode == "floor") {
-      MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
-      if (op_name == "remainder_out_mps") {
-        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
-                                                   secondaryTensor:secondaryCastTensor
-                                                              name:nil];
-        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
-      }
-      return floorTensor;
-    }
-    assert(0 && "Invalid rounding mode\n");
-    return nullptr;
-  };
-  binaryOpTensor(self,
-                 other,
-                 output,
-                 op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""),
-                 div_mode_op_block);
-}
-
 static void add_sub_lerp_template(const Tensor& self,
                                   const Tensor& other,
                                   const Scalar& alpha,
@@ -352,14 +295,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
   }
 }
 
-TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
-  mps::div_mode_template(self, other, "floor", output, "remainder_out_mps");
-}
-
-TORCH_IMPL_FUNC(fmod_mps_out)(const Tensor& self, const Tensor& other, const Tensor& output) {
-  mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out");
-}
-
 TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
   mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8ba40757482..464864a99f3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9862,8 +9862,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: fmod_out
-    MPS: fmod_mps_out
+    CPU, CUDA, MPS: fmod_out
   tags: pointwise
 
 - func: fmod.Tensor(Tensor self, Tensor other) -> Tensor
@@ -9969,8 +9968,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: remainder_out
-    MPS: remainder_out_mps
+    CPU, CUDA, MPS: remainder_out
   tags: pointwise
 
 - func: remainder.Tensor(Tensor self, Tensor other) -> Tensor
diff --git a/c10/metal/utils.h b/c10/metal/utils.h
index ae2a99292e7..92fc87c4240 100644
--- a/c10/metal/utils.h
+++ b/c10/metal/utils.h
@@ -148,6 +148,9 @@ template <typename T>
 constexpr constant bool is_scalar_integral_v =
     ::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
 
+template <typename U, typename V>
+using common_dtype = decltype(U(0) + V(0));
+
 // floor_divide
 template <
     typename T,
@@ -155,10 +158,42 @@ template <
     typename U,
     ::metal::enable_if_t<
         is_scalar_integral_v<T> && is_scalar_integral_v<U>,
         bool> = true>
-inline decltype(T(0) + U(0)) floor_divide(T x, U y) {
+inline common_dtype<T, U> floor_divide(T x, U y) {
   const auto quot = x / y;
   return (x < 0) == (y < 0) ? quot : (x % y != 0) ? quot - 1 : quot;
 }
+
+template <
+    typename T,
+    typename U,
+    ::metal::enable_if_t<
+        is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
+        bool> = true>
+inline common_dtype<T, U> floor_divide(T x, U y) {
+  return ::metal::floor(x / y);
+}
+
+// fmod
+template <
+    typename T,
+    typename U,
+    ::metal::enable_if_t<
+        is_scalar_integral_v<T> && is_scalar_integral_v<U>,
+        bool> = true>
+inline common_dtype<T, U> fmod(T x, U y) {
+  return x % y;
+}
+
+template <
+    typename T,
+    typename U,
+    ::metal::enable_if_t<
+        is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
+        bool> = true>
+inline common_dtype<T, U> fmod(T x, U y) {
+  return ::metal::fmod(x, y);
+}
+
 // cast_to primitives
 // - No-op if types are the same
 template <
@@ -197,8 +232,6 @@ inline T cast_to(const U from) {
 }
 
 // Generalizable math operators (used for both scalar and complex)
-template <typename U, typename V>
-using common_dtype = decltype(U(0) + V(0));
 
 template <
     typename T,
diff --git a/test/test_mps.py b/test/test_mps.py
index 8b8a82c897e..3b015ea0e40 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5587,6 +5587,10 @@ class TestMPS(TestCaseMPS):
             torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
         self.assertEqual(res_cpu, res_mps)
 
+        # Regression test for https://github.com/pytorch/pytorch/issues/154171
+        # Essentially, remainder over integral types should rely on integer ops
+        self.assertEqual(torch.tensor(42309891, device='mps') % torch.tensor(31, device='mps'), torch.tensor(6, device='mps'))
+
     def test_expand(self):
         def helper(n, c):
             values = [[1.0], [4.0], [7.0]]
diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py
index 13b4890f517..72dbfd676c2 100644
--- a/torch/testing/_internal/common_mps.py
+++ b/torch/testing/_internal/common_mps.py
@@ -581,7 +581,6 @@ if torch.backends.mps.is_available():
             "mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
             "matmul": [torch.int64] if MACOS_VERSION < 14.0 else [],
             "__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [],
-            "unravel_index": [torch.int32, torch.int64],
             # returned output on CPU is float64
             "bincount": [
                 torch.int16,
@@ -590,8 +589,6 @@ if torch.backends.mps.is_available():
                 torch.uint8,
                 torch.int8,
             ],
-            # trunc_tensor not working properly for float16 and bfloat16
-            "fmod": [torch.float16],
             # round not working properly for float16 and bfloat16
             "round": [torch.float16, torch.bfloat16],
             "rounddecimals_0": [torch.bfloat16],
@@ -928,8 +925,6 @@ if torch.backends.mps.is_available():
             "signal.windows.kaiser": [torch.float32],
             "signal.windows.nuttall": [torch.float32],
             "eye": [torch.float16, torch.float32],
-            # trunc_tensor not working properly for float16
-            "fmod": [torch.float16],
             # round not working properly for float16
             "round": [torch.float16],
             # topk fails with duplicate indices
@@ -942,7 +937,6 @@ if torch.backends.mps.is_available():
             "masked.softmax": [torch.float32, torch.float16],
             "masked.log_softmax": [torch.float32, torch.float16],
             "atanh": [torch.float16],
-            "__rmod__": [torch.float16],
             "triangular_solve": [torch.float32],
             # Unsupported Border padding mode, forward pass success as fallback to cpu
"grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16],