diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index fe2df46c680..7e16aa1569a 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -91,7 +91,7 @@ opmath_t threadgroup_prod( } template -float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) { +float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) { ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); float m = data[0]; float m2 = 0; @@ -100,7 +100,7 @@ float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) { m += delta / (idx + 1); m2 += delta * (data[idx] - m); } - return float2(m, m2); + return float3(m, m2, size); } // Each vec3type is tuple of mean, m2 and weight diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index e2b300b280e..525e21ac6ab 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -209,6 +209,7 @@ for test_name in [ "test_max_min", "test_max_pool2d2", "test_multilayer_prime_size", + "test_multilayer_var_lowp", "test_min_max_reduction_nan", "test_nan_to_num", "test_neg_max_uint8", diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 42964435c8f..8614790e3ce 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -640,31 +640,46 @@ class MetalKernel(SIMDKernel): dtype=dtype, ) if reduction_type == "welford_reduce": - assert not self.multistage_reduction, ( - f"Multistage reduction not yet supported for {reduction_type}" + if not self.multistage_reduction: + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") + wf_res = self.cse.generate( + self.compute, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + ) + self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap( + (f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z") + ) + return result_tuple + acc_buf = self._new_idxvar("float3", acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, float3({value}, 0.0, 1.0));" ) - acc_buf = self._new_idxvar(src_dtype, acc_buf_size) - self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") wf_res = self.cse.generate( - self.compute, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + self.stores, + f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})", ) self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap( - (f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel) + (f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z") ) return result_tuple - if reduction_type == "welford_combine": - assert not self.multistage_reduction, ( - f"Multistage reduction not yet supported for {reduction_type}" - ) assert isinstance(value, tuple), "Input to welford combine must be tuple" acc_buf = self._new_idxvar("float3", acc_buf_size) - self.compute.splice( - f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});" - ) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + inp_value = f"float3({value[0]}, {value[1]}, {value[2]})" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + if self.multistage_reduction: + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, {inp_value});" + ) + else: + self.compute.writeline(f"{acc_thread_var} = {inp_value};") wf_res = self.cse.generate( - self.compute, + self.stores if self.multistage_reduction else self.compute, f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", ) self.cse.reduction_cache[cache_key] = result_tuple = OpsWrapper._unwrap(