[MPS][BE] Move fmod/remainder to Metal ops (#154280)

This accomplishes the following:
 - Fixes a correctness problem with large integer types (this probably makes the op slower, but that cannot be avoided if one wants an accurate answer; see the sanity-check sketch below)
 - Makes the op faster for floating-point types (a Metal kernel invocation is faster than creating an MPSGraph)
 - Eliminates the need for several correctness workarounds

Fixes https://github.com/pytorch/pytorch/issues/154171
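As a quick sanity check of the integer fix (a sketch, not part of the commit; it assumes a macOS build of PyTorch with MPS available), the MPS result now matches CPU for the value from the issue:

import torch

a = torch.tensor(42309891)
b = torch.tensor(31)
# The old path cast integer inputs to float32 inside an MPSGraph, which loses
# precision for large values; the new integer Metal kernel agrees with CPU.
assert torch.equal((a.to("mps") % b.to("mps")).cpu(), a % b)                      # both give 6
assert torch.equal(torch.fmod(a.to("mps"), b.to("mps")).cpu(), torch.fmod(a, b))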
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154280
Approved by: https://github.com/dcci
ghstack dependencies: #154275, #154290
Author: Nikita Shulga
Date: 2025-05-23 18:41:05 -07:00
Committed by: PyTorch MergeBot
Parent: 8f08bdb7f2
Commit: 975bbc63db
7 changed files with 70 additions and 78 deletions

View File

@@ -246,6 +246,20 @@ struct div_trunc_functor {
}
};
struct remainder_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return T(a - b * c10::metal::floor_divide(a, b));
  }
};
struct fmod_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return c10::metal::fmod(a, b);
  }
};
// Some helper defines
#if __METAL_VERSION__ >= 310
#define _METAL_310_PLUS(x) x
@@ -304,6 +318,10 @@ REGISTER_FLOAT_BINARY_OP(div_trunc);
REGISTER_INTEGER_BINARY_OP(div_trunc);
REGISTER_OPMATH_FLOAT_BINARY_OP(div_true);
REGISTER_INT2FLOAT_BINARY_OP(div_true);
REGISTER_OPMATH_FLOAT_BINARY_OP(remainder);
REGISTER_INTEGER_BINARY_OP(remainder);
REGISTER_OPMATH_FLOAT_BINARY_OP(fmod);
REGISTER_INTEGER_BINARY_OP(fmod);
REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
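For readers less familiar with the two ops, here is a pure-Python sketch of the semantics the new remainder_functor and fmod_functor implement (an illustration of the math, not the Metal code itself): remainder follows the sign of the divisor via floor division, while fmod follows the sign of the dividend via truncating division.

import math

def remainder_ref(a, b):
    # mirrors remainder_functor: a - b * floor_divide(a, b)
    return a - b * math.floor(a / b)

def fmod_ref(a, b):
    # mirrors fmod_functor: truncating division, C-style fmod
    return a - b * math.trunc(a / b)

assert remainder_ref(-7, 3) == 2   # takes the sign of the divisor
assert fmod_ref(-7, 3) == -1       # takes the sign of the dividend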

View File

@@ -159,6 +159,14 @@ static void div_trunc_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "div_trunc");
}
static void remainder_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "remainder");
}
static void fmod_mps_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "fmod");
}
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel)
@@ -178,4 +186,6 @@ REGISTER_DISPATCH(mul_stub, &mul_mps_kernel)
REGISTER_DISPATCH(div_true_stub, &div_true_mps_kernel)
REGISTER_DISPATCH(div_floor_stub, &div_floor_mps_kernel)
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_mps_kernel)
REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
} // namespace at::native
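Once fmod_stub and remainder_stub are registered here, the shared structured kernels (see the native_functions.yaml change below) route MPS tensors straight to the new Metal kernels, which is also why the float16 fmod/__rmod__ xfails can be dropped from the test expectations further down. A small consistency check one could run (illustrative only; requires an MPS device):

import torch

x = torch.tensor([-7.5, 6.25], dtype=torch.float16)
y = torch.tensor([2.0, -3.0], dtype=torch.float16)
torch.testing.assert_close(torch.fmod(x.to("mps"), y.to("mps")).cpu(), torch.fmod(x, y))
torch.testing.assert_close(torch.remainder(x.to("mps"), y.to("mps")).cpu(), torch.remainder(x, y))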

View File

@@ -14,7 +14,6 @@
#include <ATen/ops/atan2_native.h>
#include <ATen/ops/div_native.h>
#include <ATen/ops/eq_native.h>
#include <ATen/ops/fmod_native.h>
#include <ATen/ops/ge_native.h>
#include <ATen/ops/gt_native.h>
#include <ATen/ops/hypot_native.h>
@@ -30,7 +29,6 @@
#include <ATen/ops/ne_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/pow_native.h>
#include <ATen/ops/remainder_native.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/sub_native.h>
#include <ATen/ops/view_as_real.h>
@@ -184,61 +182,6 @@ static void binaryOpScalar(const Tensor& self,
binaryOpTensor(self, wrapped_scalar_tensor(other), output, op_name, binaryBlock);
}
static void div_mode_template(const Tensor& self,
const Tensor& other,
std::optional<std::string_view> rounding_mode,
const Tensor& output,
const std::string& op_name) {
if (rounding_mode.has_value() && *rounding_mode == "trunc") {
TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input");
}
BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
primaryCastTensor = [mpsGraph castTensor:primaryCastTensor toType:MPSDataTypeFloat32 name:@"primaryCastTensor"];
secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
toType:MPSDataTypeFloat32
name:@"secondaryCastTensor"];
}
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryCastTensor
name:nil];
// Rounding is a no-op for integral types, and also a reasonable workaround
// for an MPSGraph bug on Apple Silicon that throws `Function floorOp_i64 was not found in the library`
// See https://github.com/pytorch/pytorch/issues/84995
bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
if (!rounding_mode.has_value() || !isFloatOutput) {
return divTensor;
} else if (*rounding_mode == "trunc") {
auto truncTensor = [mpsGraph truncateWithTensor:divTensor name:nil];
if (op_name == "fmod_mps_out") {
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return truncTensor;
} else if (*rounding_mode == "floor") {
MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
if (op_name == "remainder_out_mps") {
auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
secondaryTensor:secondaryCastTensor
name:nil];
return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
}
return floorTensor;
}
assert(0 && "Invalid rounding mode\n");
return nullptr;
};
binaryOpTensor(self,
other,
output,
op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""),
div_mode_op_block);
}
static void add_sub_lerp_template(const Tensor& self,
const Tensor& other,
const Scalar& alpha,
@@ -352,14 +295,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
}
}
TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, "floor", output, "remainder_out_mps");
}
TORCH_IMPL_FUNC(fmod_mps_out)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out");
}
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();

View File

@@ -9862,8 +9862,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
-CPU, CUDA: fmod_out
-MPS: fmod_mps_out
+CPU, CUDA, MPS: fmod_out
tags: pointwise
- func: fmod.Tensor(Tensor self, Tensor other) -> Tensor
@@ -9969,8 +9968,7 @@
structured: True
structured_inherits: TensorIteratorBase
dispatch:
-CPU, CUDA: remainder_out
-MPS: remainder_out_mps
+CPU, CUDA, MPS: remainder_out
tags: pointwise
- func: remainder.Tensor(Tensor self, Tensor other) -> Tensor

View File

@@ -148,6 +148,9 @@ template <typename T>
constexpr constant bool is_scalar_integral_v =
::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
template <typename U, typename V>
using common_dtype = decltype(U(0) + V(0));
// floor_divide
template <
typename T,
@@ -155,10 +158,42 @@ template <
::metal::enable_if_t<
is_scalar_integral_v<T> && is_scalar_integral_v<U>,
bool> = true>
-inline decltype(T(0) + U(0)) floor_divide(T x, U y) {
+inline common_dtype<T, U> floor_divide(T x, U y) {
const auto quot = x / y;
return (x < 0) == (y < 0) ? quot : (x % y != 0) ? quot - 1 : quot;
}
template <
typename T,
typename U,
::metal::enable_if_t<
is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
bool> = true>
inline common_dtype<T, U> floor_divide(T x, U y) {
return ::metal::floor(x / y);
}
// fmod
template <
typename T,
typename U,
::metal::enable_if_t<
is_scalar_integral_v<T> && is_scalar_integral_v<U>,
bool> = true>
inline common_dtype<T, U> fmod(T x, U y) {
return x % y;
}
template <
typename T,
typename U,
::metal::enable_if_t<
is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
bool> = true>
inline common_dtype<T, U> fmod(T x, U y) {
return ::metal::fmod(x, y);
}
// cast_to primitives
// - No-op if types are the same
template <
@@ -197,8 +232,6 @@ inline T cast_to(const U from) {
}
// Generalizable math operators (used for both scalar and complex)
-template <typename U, typename V>
-using common_dtype = decltype(U(0) + V(0));
template <
typename T,

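A side note on the integer floor_divide used by remainder_functor: Metal's integer division truncates toward zero, so the helper subtracts one from the quotient when the operands have opposite signs and the division is inexact. A pure-Python sketch of the same adjustment (illustration only; Python's own // already floors):

def floor_divide_ref(x, y):
    # emulate C/Metal integer division, which truncates toward zero
    quot = abs(x) // abs(y) * (1 if (x < 0) == (y < 0) else -1)
    inexact = (quot * y != x)
    # the adjustment from the header above: step down to the floor when needed
    return quot - 1 if (x < 0) != (y < 0) and inexact else quot

assert floor_divide_ref(7, 2) == 3
assert floor_divide_ref(-7, 2) == -4               # truncation alone would give -3
assert floor_divide_ref(42309891, 31) == 1364835   # stays exact for large integers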
View File

@@ -5587,6 +5587,10 @@ class TestMPS(TestCaseMPS):
torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
self.assertEqual(res_cpu, res_mps)
# Regression test for https://github.com/pytorch/pytorch/issues/154171
# Essentially, remainder over integral types should rely on integer ops
self.assertEqual(torch.tensor(42309891, device='mps') % torch.tensor(31, device='mps'), torch.tensor(6, device='mps'))
def test_expand(self):
def helper(n, c):
values = [[1.0], [4.0], [7.0]]

View File

@@ -581,7 +581,6 @@ if torch.backends.mps.is_available():
"mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
"matmul": [torch.int64] if MACOS_VERSION < 14.0 else [],
"__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [],
"unravel_index": [torch.int32, torch.int64],
# returned output on CPU is float64
"bincount": [
torch.int16,
@@ -590,8 +589,6 @@ if torch.backends.mps.is_available():
torch.uint8,
torch.int8,
],
-# trunc_tensor not working properly for float16 and bfloat16
-"fmod": [torch.float16],
# round not working properly for float16 and bfloat16
"round": [torch.float16, torch.bfloat16],
"rounddecimals_0": [torch.bfloat16],
@@ -928,8 +925,6 @@ if torch.backends.mps.is_available():
"signal.windows.kaiser": [torch.float32],
"signal.windows.nuttall": [torch.float32],
"eye": [torch.float16, torch.float32],
-# trunc_tensor not working properly for float16
-"fmod": [torch.float16],
# round not working properly for float16
"round": [torch.float16],
# topk fails with duplicate indices
@@ -942,7 +937,6 @@ if torch.backends.mps.is_available():
"masked.softmax": [torch.float32, torch.float16],
"masked.log_softmax": [torch.float32, torch.float16],
"atanh": [torch.float16],
"__rmod__": [torch.float16],
"triangular_solve": [torch.float32],
# Unsupported Border padding mode, forward pass success as fallback to cpu
"grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16],