Revert "[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)"

This reverts commit 71073caa00.

Reverted https://github.com/pytorch/pytorch/pull/151152 on behalf of https://github.com/malfet due to another lint failure ([comment](https://github.com/pytorch/pytorch/pull/151152#issuecomment-2799027274))
This commit is contained in:
PyTorch MergeBot
2025-04-12 20:27:48 +00:00
parent 3dcb46c30e
commit 7762bddd87
3 changed files with 17 additions and 33 deletions

View File

@@ -91,7 +91,7 @@ opmath_t<T> threadgroup_prod(
}
template <typename T>
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
float m = data[0];
float m2 = 0;
@@ -100,7 +100,7 @@ float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
m += delta / (idx + 1);
m2 += delta * (data[idx] - m);
}
return float3(m, m2, size);
return float2(m, m2);
}
// Each vec3type is a tuple of mean, m2, and weight