[MPSInductor] Fix remainder implementation for int types (#155891)

Introduce `c10::metal::remainder` and call it from both inductor and eager implementation, with integer specialization, which should make it much faster than before, while still compliant with Python way of rounding up negative numbers.

This allows one to remove complex type detection logic from mps codegen and rely on Metal(C++) type system to figure out input and output types.

This fixes compilation of something like
```python
@torch.compile
def f(x, y):
    return x[y % 5]
```

which beforehand failed to compile with
```
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    #include <c10/metal/utils.h>
    kernel void generated_kernel(
        device float* out_ptr0,
        constant long* in_ptr0,
        constant float* in_ptr1,
        uint xindex [[thread_position_in_grid]]
    ) {
        int x0 = xindex;
        auto tmp0 = in_ptr0[x0];
        auto tmp1 = 12;
        auto tmp2 = static_cast<float>(tmp0) - static_cast<float>(tmp1) * metal::floor(static_cast<float>(tmp0) / static_cast<float>(tmp1));
        auto tmp3 = 1024;
        auto tmp4 = static_cast<long>(tmp3);
        auto tmp5 = tmp2 + tmp4;
        auto tmp6 = tmp2 < 0;
        auto tmp7 = tmp6 ? tmp5 : tmp2;
        if ((tmp7 < 0) && (tmp7 > 1024)) return;
        auto tmp9 = in_ptr1[tmp7];
        out_ptr0[x0] = static_cast<float>(tmp9);
    }
 with program_source:372:28: error: array subscript is not an integer
        auto tmp9 = in_ptr1[tmp7];
                           ^~~~~
```

This fixes fail_to_compile for GPT2ForSequenceClassification Huggingface model using `transformers==4.44.2`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155891
Approved by: https://github.com/manuelcandales
This commit is contained in:
Nikita Shulga
2025-06-13 07:18:01 -07:00
committed by PyTorch MergeBot
parent 9462106b7e
commit b6add8c8ba
4 changed files with 32 additions and 19 deletions

View File

@@ -265,6 +265,30 @@ inline common_dtype<T, U> div(const T x, const U y) {
return T(::metal::dot(x, y), x.y * y.x - x.x * y.y) / ::metal::dot(y, y);
}
// Remainder operator
template <
typename T,
typename U,
::metal::enable_if_t<
is_scalar_floating_point_v<T> || is_scalar_floating_point_v<U>,
bool> = true>
inline float remainder(const T x, const U y) {
const auto x_f = static_cast<float>(x);
const auto y_f = static_cast<float>(y);
return x_f - y_f * floor_divide(x_f, y_f);
}
template <
typename T,
typename U,
::metal::enable_if_t<
is_scalar_integral_v<T> && is_scalar_integral_v<U>,
bool> = true>
inline common_dtype<T, U> remainder(const T x, const U y) {
auto rc = x % y;
return rc == 0 || (x ^ y) > 0 ? rc : rc + y;
}
// Based on algorithm described in
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {