fix for fp16 (#134106)
This PR is a replacement for https://github.com/pytorch/pytorch/pull/133085, pushing a quick fix for RMSNorm. The original author is @kkontny.

Previous PR summary: Since FP16 has quite a small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and this happens in real-world computation. I tried to use the fused `nn.RMSNorm` implementation instead of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started to give wrong answers in FP16 while still giving good ones in FP32. I figured out this happens due to overflow while computing the square of the input tensor. The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and give better numerical stability.

```
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way to be usable in real-world implementations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134106
Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy
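A minimal sketch of the failure mode (illustrative values, not taken from the PR): the largest representable FP16 value is 65504, so any element with magnitude above roughly 256 overflows to `inf` once squared, and the subsequent `rsqrt` collapses the whole normalization to zero. Upcasting before squaring avoids this:

```
import torch

# Hypothetical input: one plausible activation scale that already breaks fp16.
x = torch.full((8,), 300.0, dtype=torch.float16)
eps = 1e-6

# Naive fp16 path: 300**2 = 90000 overflows to inf, rsqrt(inf) = 0, output is all zeros.
naive = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
print(naive)

# Upcasted path: compute in fp32, then cast back; gives ~1.0 per element as expected
# for a constant vector.
stable = x.float()
stable = (stable * torch.rsqrt(stable.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)
print(stable)
```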
Committed by: PyTorch MergeBot
Parent: 66db61f0d1
Commit: 9a04cfbeff
```
@@ -6,6 +6,7 @@
 #include <ATen/Parallel.h>
 #include <ATen/native/cpu/mixed_data_type.h>
 #include <c10/util/irange.h>
+#include <ATen/OpMathType.h>

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
```
```
@@ -295,7 +296,12 @@ Tensor rms_norm(
     eps_val = eps.value();
   }

-  auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));
+  // upcast is needed for fp16 and bf16
+  c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
+  Tensor upcasted_input = input.to(opmath_t);
+
+  Tensor rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val));
+  Tensor result = upcasted_input.mul(rqrst_input).type_as(input);

   if (weight_opt.has_value()) {
     result = result.mul(weight_opt.value());
```
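For readers who prefer not to parse the ATen code, here is a rough Python sketch of the new composite path (an approximation only: `toOpMathType` is emulated with `.float()` for fp16/bf16 inputs, and the reduction dims are passed explicitly instead of `dims_to_reduce_ref`):

```
import torch

def rms_norm_sketch(input, dims, eps, weight=None):
    # Upcast low-precision inputs before squaring, mirroring the C++ change above.
    upcasted = input.float() if input.dtype in (torch.float16, torch.bfloat16) else input
    rqrst_input = torch.rsqrt(upcasted.pow(2).mean(dim=dims, keepdim=True) + eps)
    result = (upcasted * rqrst_input).type_as(input)
    if weight is not None:
        result = result * weight
    return result

# Example: large fp16 activations no longer collapse to zero.
print(rms_norm_sketch(torch.full((2, 4), 300.0, dtype=torch.float16), dims=[-1], eps=1e-6))
```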
```
@@ -4541,7 +4541,7 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
     )

 def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
-    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000)

     # Ordered as input shape, normalized_shape and a kwarg dict for eps
     cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
```
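The `high=1000` bound matters because squares of such samples (up to 1e6) sit far above the FP16 maximum of 65504, so the OpInfo samples now actually exercise the overflow path that the C++ fix guards against. A quick check (not part of the test file):

```
import torch

t = torch.tensor(1000.0, dtype=torch.float16)
print(t * t)                   # inf: 1e6 does not fit in fp16 (max ~65504)
print(t.float() * t.float())   # 1000000.0 once upcast to fp32
```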
```
@@ -1937,7 +1937,9 @@ def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, training, **kwargs):
         normalized_shape = m.normalized_shape
         weight = m.weight
         dims = [ndim - i - 1 for i in range(len(normalized_shape))]
-        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        upcasted_i = i.float()
+        result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        result = result.type_as(i)
         if weight is not None:
             result *= weight
         return result
```
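As a sanity check (a sketch, not part of the test suite), the updated reference should now agree with the composite `torch.nn.functional.rms_norm` even when fp16 inputs are large enough to overflow a naive fp16 square; tolerances below are chosen loosely for illustration:

```
import torch
import torch.nn.functional as F

i = torch.randn(2, 8, dtype=torch.float16) * 300.0   # large enough to overflow fp16 squares
eps = 1e-6

# Reference computation, mirroring the updated test helper above.
upcasted_i = i.float()
ref = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=[-1], keepdim=True) + eps)
ref = ref.type_as(i)

out = F.rms_norm(i, normalized_shape=[8], eps=eps)
print(torch.allclose(out.float(), ref.float(), rtol=1e-2, atol=1e-2))
```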