mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Revert "[MPSInductor] Fix larger-than-threadgroup Welford reductions (#151152)"
This reverts commit 71073caa00.
Reverted https://github.com/pytorch/pytorch/pull/151152 on behalf of https://github.com/malfet due to Another lint failure ([comment](https://github.com/pytorch/pytorch/pull/151152#issuecomment-2799027274))
This commit is contained in:
@@ -91,7 +91,7 @@ opmath_t<T> threadgroup_prod(
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
float m = data[0];
|
||||
float m2 = 0;
|
||||
@@ -100,7 +100,7 @@ float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
|
||||
m += delta / (idx + 1);
|
||||
m2 += delta * (data[idx] - m);
|
||||
}
|
||||
return float3(m, m2, size);
|
||||
return float2(m, m2);
|
||||
}
|
||||
|
||||
// Each vec3type is tuple of mean, m2 and weight
|
||||
|
||||
Reference in New Issue
Block a user