fix for fp16 (#134106)

This PR is a replacement for https://github.com/pytorch/pytorch/pull/133085, pushing a quick fix for RMSNorm.
The original author is @kkontny.

Previous PR summary:
Since FP16 has quite a small dynamic range, it is very easy to overflow while computing `at::pow(input, 2)`, and this happens in real-world computation.
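For illustration, FP16's maximum representable value is 65504, so squaring any value above roughly 256 already produces `inf`:

```python
import torch

x = torch.tensor([300.0], dtype=torch.float16)
print(x.pow(2))          # tensor([inf], dtype=torch.float16) -- 300**2 = 90000 > 65504
print(x.float().pow(2))  # tensor([90000.]) -- upcasting to FP32 avoids the overflow
```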

I've tried to use the fused `nn.RMSNorm` implementation instead of `LlamaRMSNorm` inside the `transformers` implementation of Llama (`src/transformers/models/llama/modeling_llama.py`). It started to give wrong answers in FP16 while still giving good ones in FP32. I figured out this happens due to overflow while computing the square of the input tensor.

The original `LlamaRMSNorm` implementation upcasts the input to FP32 to prevent this and to give better numerical stability:

```python
import torch
from torch import nn


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```
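Continuing the snippet above, a quick way to see the failure mode: with inputs large enough that the FP16 square overflows, a naive FP16 computation returns zeros (the variance becomes `inf`, so `rsqrt` returns 0), while the upcasting forward stays correct:

```python
x = torch.full((1, 8), 400.0, dtype=torch.float16)

print(LlamaRMSNorm(hidden_size=8)(x))  # ~1.0 everywhere: variance is computed in FP32
naive = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(naive)                           # all zeros: 400**2 overflows to inf, rsqrt(inf) == 0
```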

The proposed commit fixes the issue. FP16 in RMSNorm has to be treated in a special way (upcast before squaring) to be usable in real-world implementations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134106
Approved by: https://github.com/mikaylagawarecki, https://github.com/eqy
Author: Mayank Mishra
Date: 2024-09-11 22:02:04 +00:00
Committed by: PyTorch MergeBot
Parent: 66db61f0d1
Commit: 9a04cfbeff

3 changed files with 11 additions and 3 deletions

```diff
@@ -6,6 +6,7 @@
 #include <ATen/Parallel.h>
 #include <ATen/native/cpu/mixed_data_type.h>
 #include <c10/util/irange.h>
+#include <ATen/OpMathType.h>
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -295,7 +296,12 @@ Tensor rms_norm(
     eps_val = eps.value();
   }
-  auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)));
+  // upcast is needed for fp16 and bf16
+  c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
+  Tensor upcasted_input = input.to(opmath_t);
+  Tensor rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val));
+  Tensor result = upcasted_input.mul(rqrst_input).type_as(input);
   if (weight_opt.has_value()) {
     result = result.mul(weight_opt.value());
```
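In Python terms, the new composite does roughly the following (a sketch of the C++ above, not the actual dispatcher code; `dims` and `eps` stand in for `dims_to_reduce_ref` and `eps_val`):

```python
import torch

def rms_norm_sketch(input, dims, eps, weight=None):
    # toOpMathType: FP16/BF16 compute in FP32, other dtypes stay as-is
    opmath = torch.float32 if input.dtype in (torch.float16, torch.bfloat16) else input.dtype
    upcasted = input.to(opmath)
    rqrst = torch.rsqrt(upcasted.pow(2).mean(dim=dims, keepdim=True) + eps)
    result = upcasted.mul(rqrst).type_as(input)  # cast back to the input dtype
    if weight is not None:
        result = result.mul(weight)
    return result
```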

```diff
@@ -4541,7 +4541,7 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar
 )
 def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs):
-    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000)
     # Ordered as input shape, normalized_shape and a kwarg dict for eps
     cases: Tuple[Tuple[int], Tuple[int], dict] = (  # type: ignore[assignment]
```
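The `high=1000` bound presumably makes the OpInfo samples large enough to exercise the old overflow path: any FP16 value above `sqrt(65504) ≈ 256` squares to `inf`:

```python
import torch

print(torch.finfo(torch.float16).max)                    # 65504.0
print(torch.tensor(1000.0, dtype=torch.float16).pow(2))  # tensor(inf, dtype=torch.float16)
```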

```diff
@@ -1937,7 +1937,9 @@ def module_inputs_torch_nn_RMSNorm(module_info, device, dtype, requires_grad, tr
         normalized_shape = m.normalized_shape
         weight = m.weight
         dims = [ndim - i - 1 for i in range(len(normalized_shape))]
-        result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        upcasted_i = i.float()
+        result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps)
+        result = result.type_as(i)
         if weight is not None:
             result *= weight
         return result
```
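With the patch applied, the fused op should agree with this upcast reference even for large FP16 inputs; a sanity check (assumes a PyTorch build that includes this change and exposes `torch.nn.functional.rms_norm`):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.float16) * 500  # large enough to overflow FP16 squares

fused = F.rms_norm(x, normalized_shape=[8], eps=1e-6)
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).half()
torch.testing.assert_close(fused, ref, rtol=1e-3, atol=1e-3)
```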