diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp
index b11bcaba38e..6e0abb14f8b 100644
--- a/aten/src/ATen/native/layer_norm.cpp
+++ b/aten/src/ATen/native/layer_norm.cpp
@@ -6,6 +6,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/native/cpu/mixed_data_type.h>
 #include <c10/util/irange.h>
+#include <ATen/OpMathType.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -295,7 +296,12 @@ Tensor rms_norm(
     eps_val = eps.value();
   }
 
-  auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));
+  // upcast is needed for fp16 and bf16
+  c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
+  Tensor upcasted_input = input.to(opmath_t);
+
+  Tensor rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val));
+  Tensor result = upcasted_input.mul(rqrst_input).type_as(input);
 
   if (weight_opt.has_value()) {
     result = result.mul(weight_opt.value());
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a6f67d75659..a174154349f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4541,7 +4541,7 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
 )
 
 def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
-    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000)
 
     # Ordered as input shape, normalized_shape and a kwarg dict for eps
     cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 00251ca264f..63963bab1b0 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -1937,7 +1937,9 @@ def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
         normalized_shape = m.normalized_shape
         weight = m.weight
         dims = [ndim - i - 1 for i in range(len(normalized_shape))]
-        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        upcasted_i = i.float()
+        result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        result = result.type_as(i)
         if weight is not None:
             result *= weight
         return result