mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
[MPS][BE] Move fmod/remainder to Metal ops (#154280)
This accomplishes following: - Fixes correctness problem with large integer types (though probably makes it slower, but this could not be avoided if one wants to compute accurate answer) - Makes op faster for floating point types (as Metal kernel invocation is faster than creating MPSGraph) - Eliminates need for several correctness workarounds Fixes https://github.com/pytorch/pytorch/issues/154171 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154280 Approved by: https://github.com/dcci ghstack dependencies: #154275, #154290
This commit is contained in:
committed by
PyTorch MergeBot
parent
8f08bdb7f2
commit
975bbc63db
@@ -246,6 +246,20 @@ struct div_trunc_functor {
|
||||
}
|
||||
};
|
||||
|
||||
// Python-style remainder (torch.remainder): the result takes the sign of
// the divisor `b`. Computed as a - b * floor_divide(a, b); the
// c10::metal::floor_divide helper keeps integral inputs in exact integer
// arithmetic, which fixes wrong answers for large integers that a
// float-based graph implementation produced (pytorch/pytorch#154171).
struct remainder_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return T(a - b * c10::metal::floor_divide(a, b));
  }
};
|
||||
|
||||
// C-style fmod (torch.fmod): the result takes the sign of the dividend `a`.
// Delegates to c10::metal::fmod, which uses the `%` operator for integral
// types and ::metal::fmod for floating-point types.
struct fmod_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return c10::metal::fmod(a, b);
  }
};
|
||||
|
||||
// Some helper defines
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#define _METAL_310_PLUS(x) x
|
||||
@@ -304,6 +318,10 @@ REGISTER_FLOAT_BINARY_OP(div_trunc);
|
||||
REGISTER_INTEGER_BINARY_OP(div_trunc);
|
||||
REGISTER_OPMATH_FLOAT_BINARY_OP(div_true);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(div_true);
|
||||
REGISTER_OPMATH_FLOAT_BINARY_OP(remainder);
|
||||
REGISTER_INTEGER_BINARY_OP(remainder);
|
||||
REGISTER_OPMATH_FLOAT_BINARY_OP(fmod);
|
||||
REGISTER_INTEGER_BINARY_OP(fmod);
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
|
||||
|
||||
@@ -159,6 +159,14 @@ static void div_trunc_mps_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_binary_kernel(iter, "div_trunc");
|
||||
}
|
||||
|
||||
// Dispatch-stub implementation of torch.remainder on MPS: runs the
// "remainder" Metal binary kernel (remainder_functor) over the iterator's
// operands instead of building an MPSGraph.
static void remainder_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "remainder");
}
|
||||
|
||||
// Dispatch-stub implementation of torch.fmod on MPS: runs the "fmod" Metal
// binary kernel (fmod_functor) over the iterator's operands.
static void fmod_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "fmod");
}
|
||||
|
||||
REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel)
|
||||
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel)
|
||||
REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel)
|
||||
@@ -178,4 +186,6 @@ REGISTER_DISPATCH(mul_stub, &mul_mps_kernel)
|
||||
REGISTER_DISPATCH(div_true_stub, &div_true_mps_kernel)
|
||||
REGISTER_DISPATCH(div_floor_stub, &div_floor_mps_kernel)
|
||||
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_mps_kernel)
|
||||
REGISTER_DISPATCH(fmod_stub, &fmod_mps_kernel)
|
||||
REGISTER_DISPATCH(remainder_stub, &remainder_mps_kernel)
|
||||
} // namespace at::native
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include <ATen/ops/atan2_native.h>
|
||||
#include <ATen/ops/div_native.h>
|
||||
#include <ATen/ops/eq_native.h>
|
||||
#include <ATen/ops/fmod_native.h>
|
||||
#include <ATen/ops/ge_native.h>
|
||||
#include <ATen/ops/gt_native.h>
|
||||
#include <ATen/ops/hypot_native.h>
|
||||
@@ -30,7 +29,6 @@
|
||||
#include <ATen/ops/ne_native.h>
|
||||
#include <ATen/ops/pow.h>
|
||||
#include <ATen/ops/pow_native.h>
|
||||
#include <ATen/ops/remainder_native.h>
|
||||
#include <ATen/ops/result_type.h>
|
||||
#include <ATen/ops/sub_native.h>
|
||||
#include <ATen/ops/view_as_real.h>
|
||||
@@ -184,61 +182,6 @@ static void binaryOpScalar(const Tensor& self,
|
||||
binaryOpTensor(self, wrapped_scalar_tensor(other), output, op_name, binaryBlock);
|
||||
}
|
||||
|
||||
// Builds and runs an MPSGraph for division with an optional rounding mode
// ("trunc" or "floor"). Also used as the backing implementation for
// fmod/remainder (op_name "fmod_mps_out"/"remainder_out_mps"), which are
// computed as a - round(a / b) * b with the matching rounding.
// - self/other: input tensors; output: pre-allocated result tensor.
// - rounding_mode: nullopt for true division, "trunc" or "floor" otherwise.
static void div_mode_template(const Tensor& self,
                              const Tensor& other,
                              std::optional<std::string_view> rounding_mode,
                              const Tensor& output,
                              const std::string& op_name) {
  // MPSGraph truncation is unsupported for float16 inputs — reject early.
  if (rounding_mode.has_value() && *rounding_mode == "trunc") {
    TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input");
  }
  BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
    MPSGraph* mpsGraph = cachedGraph->graph();
    bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
    // Integer division with floor/trunc rounding is performed in float32,
    // then rounded; NOTE(review): this float round-trip loses precision for
    // large int64 values — confirm against pytorch/pytorch#154171.
    if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
      primaryCastTensor = [mpsGraph castTensor:primaryCastTensor toType:MPSDataTypeFloat32 name:@"primaryCastTensor"];
      secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
                                          toType:MPSDataTypeFloat32
                                            name:@"secondaryCastTensor"];
    }
    MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
                                                    secondaryTensor:secondaryCastTensor
                                                               name:nil];
    // Rounding is a no-op for integral types, and also a reasonable workaround
    // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
    // See https://github.com/pytorch/pytorch/issues/84995
    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
    if (!rounding_mode.has_value() || !isFloatOutput) {
      // True division, or integral output where rounding is a no-op.
      return divTensor;
    } else if (*rounding_mode == "trunc") {
      auto truncTensor = [mpsGraph truncateWithTensor:divTensor name:nil];
      // fmod(a, b) == a - trunc(a / b) * b
      if (op_name == "fmod_mps_out") {
        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
                                                   secondaryTensor:secondaryCastTensor
                                                              name:nil];
        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
      }
      return truncTensor;
    } else if (*rounding_mode == "floor") {
      MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
      // remainder(a, b) == a - floor(a / b) * b
      if (op_name == "remainder_out_mps") {
        auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
                                                   secondaryTensor:secondaryCastTensor
                                                              name:nil];
        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor secondaryTensor:mulTensor name:nil];
      }
      return floorTensor;
    }
    // Unreachable: rounding_mode is validated by the callers.
    assert(0 && "Invalid rounding mode\n");
    return nullptr;
  };
  // The rounding mode is folded into the cache key so that "trunc", "floor"
  // and true division each get their own cached graph.
  binaryOpTensor(self,
                 other,
                 output,
                 op_name + "_mps:" + (rounding_mode.has_value() ? c10::str(*rounding_mode) : ""),
                 div_mode_op_block);
}
|
||||
|
||||
static void add_sub_lerp_template(const Tensor& self,
|
||||
const Tensor& other,
|
||||
const Scalar& alpha,
|
||||
@@ -352,14 +295,6 @@ TORCH_IMPL_FUNC(pow_Scalar_out_mps)(const Scalar& base, const Tensor& exp, const
|
||||
}
|
||||
}
|
||||
|
||||
// torch.remainder on MPS: modulo with floor rounding, i.e. the result takes
// the sign of `other`. Delegates to the shared div_mode_template graph
// builder with rounding_mode == "floor".
TORCH_IMPL_FUNC(remainder_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
  mps::div_mode_template(self, other, "floor", output, "remainder_out_mps");
}
|
||||
|
||||
// torch.fmod on MPS: modulo with truncation rounding, i.e. the result takes
// the sign of `self`. Delegates to the shared div_mode_template graph
// builder with rounding_mode == "trunc".
TORCH_IMPL_FUNC(fmod_mps_out)(const Tensor& self, const Tensor& other, const Tensor& output) {
  mps::div_mode_template(self, other, "trunc", output, "fmod_mps_out");
}
|
||||
|
||||
TORCH_IMPL_FUNC(hypot_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
|
||||
mps::BinaryOpBlock hypot_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
|
||||
@@ -9862,8 +9862,7 @@
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA: fmod_out
|
||||
MPS: fmod_mps_out
|
||||
CPU, CUDA, MPS: fmod_out
|
||||
tags: pointwise
|
||||
|
||||
- func: fmod.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
@@ -9969,8 +9968,7 @@
|
||||
structured: True
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA: remainder_out
|
||||
MPS: remainder_out_mps
|
||||
CPU, CUDA, MPS: remainder_out
|
||||
tags: pointwise
|
||||
|
||||
- func: remainder.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
|
||||
@@ -148,6 +148,9 @@ template <typename T>
|
||||
constexpr constant bool is_scalar_integral_v =
|
||||
::metal::is_integral_v<T> && ::metal::is_scalar_v<T>;
|
||||
|
||||
template <typename U, typename V>
|
||||
using common_dtype = decltype(U(0) + V(0));
|
||||
|
||||
// floor_divide
|
||||
// floor_divide for integral scalar types: exact integer floor division.
// C++ `/` truncates toward zero, so when the operands' signs differ and the
// division is inexact the truncated quotient is one above the floor — adjust
// by subtracting 1 in that case. No float round-trip, so results are exact
// for the full integer range.
// NOTE: the original span contained two consecutive signature lines (the old
// `decltype(T(0) + U(0))` return spelling and the new `common_dtype<T, U>`
// alias) — stale diff residue that would not compile; only the current
// signature is kept.
template <
    typename T,
    typename U,
    ::metal::enable_if_t<
        is_scalar_integral_v<T> && is_scalar_integral_v<U>,
        bool> = true>
inline common_dtype<T, U> floor_divide(T x, U y) {
  const auto quot = x / y;
  return (x < 0) == (y < 0) ? quot : (x % y != 0) ? quot - 1 : quot;
}
|
||||
|
||||
// floor_divide for floating-point scalar types: divide, then round the
// quotient down with ::metal::floor. Result type is the usual arithmetic
// common type of the two operands.
template <
    typename T,
    typename U,
    ::metal::enable_if_t<
        is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
        bool> = true>
inline common_dtype<T, U> floor_divide(T x, U y) {
  return ::metal::floor(x / y);
}
|
||||
|
||||
// fmod
// Integral overload: C-style modulo maps directly onto the `%` operator
// (the result takes the sign of the dividend `x`, matching C semantics),
// keeping the computation exact for the full integer range.
template <
    typename T,
    typename U,
    ::metal::enable_if_t<
        is_scalar_integral_v<T> && is_scalar_integral_v<U>,
        bool> = true>
inline common_dtype<T, U> fmod(T x, U y) {
  return x % y;
}
|
||||
|
||||
// fmod
// Floating-point overload: delegates to the Metal standard library's
// ::metal::fmod (result takes the sign of the dividend `x`).
template <
    typename T,
    typename U,
    ::metal::enable_if_t<
        is_scalar_floating_point_v<T> && is_scalar_floating_point_v<U>,
        bool> = true>
inline common_dtype<T, U> fmod(T x, U y) {
  return ::metal::fmod(x, y);
}
|
||||
|
||||
// cast_to primitives
|
||||
// - No-op if types as the same
|
||||
template <
|
||||
@@ -197,8 +232,6 @@ inline T cast_to(const U from) {
|
||||
}
|
||||
|
||||
// Generalizable math operators (used for both scalar and complex)
|
||||
template <typename U, typename V>
|
||||
using common_dtype = decltype(U(0) + V(0));
|
||||
|
||||
template <
|
||||
typename T,
|
||||
|
||||
@@ -5587,6 +5587,10 @@ class TestMPS(TestCaseMPS):
|
||||
torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
|
||||
self.assertEqual(res_cpu, res_mps)
|
||||
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/154171
|
||||
# Essentially remainder over integral types should rely on integer ops
|
||||
self.assertEqual(torch.tensor(42309891, device='mps') % torch.tensor(31, device='mps'), torch.tensor(6, device='mps'))
|
||||
|
||||
def test_expand(self):
|
||||
def helper(n, c):
|
||||
values = [[1.0], [4.0], [7.0]]
|
||||
|
||||
@@ -581,7 +581,6 @@ if torch.backends.mps.is_available():
|
||||
"mat": [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
"matmul": [torch.int64] if MACOS_VERSION < 14.0 else [],
|
||||
"__rmatmul__": [torch.int64] if MACOS_VERSION < 14.0 else [],
|
||||
"unravel_index": [torch.int32, torch.int64],
|
||||
# returned output on CPU is float64
|
||||
"bincount": [
|
||||
torch.int16,
|
||||
@@ -590,8 +589,6 @@ if torch.backends.mps.is_available():
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
],
|
||||
# trunc_tensor not working properly for float16 and bfloat16
|
||||
"fmod": [torch.float16],
|
||||
# round not working properly for float16 and bfloat16
|
||||
"round": [torch.float16, torch.bfloat16],
|
||||
"rounddecimals_0": [torch.bfloat16],
|
||||
@@ -928,8 +925,6 @@ if torch.backends.mps.is_available():
|
||||
"signal.windows.kaiser": [torch.float32],
|
||||
"signal.windows.nuttall": [torch.float32],
|
||||
"eye": [torch.float16, torch.float32],
|
||||
# trunc_tensor not working properly for float16
|
||||
"fmod": [torch.float16],
|
||||
# round not working properly for float16
|
||||
"round": [torch.float16],
|
||||
# topk fails with duplicate indices
|
||||
@@ -942,7 +937,6 @@ if torch.backends.mps.is_available():
|
||||
"masked.softmax": [torch.float32, torch.float16],
|
||||
"masked.log_softmax": [torch.float32, torch.float16],
|
||||
"atanh": [torch.float16],
|
||||
"__rmod__": [torch.float16],
|
||||
"triangular_solve": [torch.float32],
|
||||
# Unsupported Border padding mode, forward pass success as fallback to cpu
|
||||
"grid_sampler_2d": [torch.float32, torch.float16, torch.bfloat16],
|
||||
|
||||
Reference in New Issue
Block a user