[MPSInductor] Fix remainder implementation for int types (#155891)
Introduce `c10::metal::remainder` and call it from both the inductor and eager implementations, with an integer specialization, which should make it much faster than before while remaining compliant with Python's rounding semantics for negative numbers.
This also allows removing the complex type-detection logic from the MPS codegen and relying on the Metal (C++) type system to figure out the input and output types.
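For reference, Python's `%` is a floored modulo whose result takes the sign of the divisor, whereas the built-in C++/Metal `%` is truncated and takes the sign of the dividend. Below is a minimal host-side C++ sketch of the sign adjustment used by the integer specialization; the `py_mod` name is hypothetical and this is plain C++, not the Metal header itself:
```cpp
#include <iostream>

// Hypothetical host-side mirror of the integer specialization's sign fix.
long py_mod(long x, long y) {
  long rc = x % y;  // truncated remainder: sign follows x
  // When the result is non-zero and the signs of x and y disagree,
  // shift it into the divisor's sign range, matching Python's semantics.
  return (rc == 0 || (x ^ y) > 0) ? rc : rc + y;
}

int main() {
  std::cout << (-7 % 5) << "\n";       // -2: C++ truncated remainder
  std::cout << py_mod(-7, 5) << "\n";  //  3: matches Python's -7 % 5
  std::cout << py_mod(7, -5) << "\n";  // -3: matches Python's 7 % -5
  return 0;
}
```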
This fixes compilation of something like
```python
@torch.compile
def f(x, y):
    return x[y % 5]
```
which previously failed to compile with
```
torch._inductor.exc.InductorError: SyntaxError: failed to compile
#include <c10/metal/utils.h>
kernel void generated_kernel(
    device float* out_ptr0,
    constant long* in_ptr0,
    constant float* in_ptr1,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = 12;
    auto tmp2 = static_cast<float>(tmp0) - static_cast<float>(tmp1) * metal::floor(static_cast<float>(tmp0) / static_cast<float>(tmp1));
    auto tmp3 = 1024;
    auto tmp4 = static_cast<long>(tmp3);
    auto tmp5 = tmp2 + tmp4;
    auto tmp6 = tmp2 < 0;
    auto tmp7 = tmp6 ? tmp5 : tmp2;
    if ((tmp7 < 0) && (tmp7 > 1024)) return;
    auto tmp9 = in_ptr1[tmp7];
    out_ptr0[x0] = static_cast<float>(tmp9);
}
with program_source:372:28: error: array subscript is not an integer
auto tmp9 = in_ptr1[tmp7];
^~~~~
```
This fixes a fail_to_compile for the GPT2ForSequenceClassification Hugging Face model with `transformers==4.44.2`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155891
Approved by: https://github.com/manuelcandales
Committed by: PyTorch MergeBot
Parent: 9462106b7e
Commit: b6add8c8ba
```diff
@@ -265,6 +265,30 @@ inline common_dtype<T, U> div(const T x, const U y) {
   return T(::metal::dot(x, y), x.y * y.x - x.x * y.y) / ::metal::dot(y, y);
 }
 
+// Remainder operator
+template <
+    typename T,
+    typename U,
+    ::metal::enable_if_t<
+        is_scalar_floating_point_v<T> || is_scalar_floating_point_v<U>,
+        bool> = true>
+inline float remainder(const T x, const U y) {
+  const auto x_f = static_cast<float>(x);
+  const auto y_f = static_cast<float>(y);
+  return x_f - y_f * floor_divide(x_f, y_f);
+}
+
+template <
+    typename T,
+    typename U,
+    ::metal::enable_if_t<
+        is_scalar_integral_v<T> && is_scalar_integral_v<U>,
+        bool> = true>
+inline common_dtype<T, U> remainder(const T x, const U y) {
+  auto rc = x % y;
+  return rc == 0 || (x ^ y) > 0 ? rc : rc + y;
+}
+
 // Based on algorithm described in
 // https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
 inline float log1p(float x) {
```
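The floating-point overload computes `x - y * floor(x / y)`, which gives the same floored semantics. A small standalone C++ check of that formula against the values Python's `%` would produce, assuming `floor_divide(x, y)` in the header behaves like `metal::floor(x / y)` for floats (this sketch uses `std::floor`, not the Metal header):
```cpp
#include <cassert>
#include <cmath>

// Hypothetical host-side mirror of the floating-point remainder formula.
float floored_remainder(float x, float y) {
  return x - y * std::floor(x / y);  // result takes the sign of y
}

int main() {
  assert(floored_remainder(-7.5f, 2.0f) == 0.5f);   // Python: -7.5 % 2.0 == 0.5
  assert(floored_remainder(7.5f, -2.0f) == -0.5f);  // Python: 7.5 % -2.0 == -0.5
  assert(std::fmod(-7.5f, 2.0f) == -1.5f);          // truncated fmod keeps x's sign
  return 0;
}
```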