mirror of
https://github.com/zebrajr/pytorch.git
synced 2026-01-15 12:15:51 +00:00
Add compilable foreach RAdam support (#117912)
Fixes https://github.com/pytorch/pytorch/issues/117807 This brings the number of supported optimizers with `torch.compile` to 11/13 (!) Pull Request resolved: https://github.com/pytorch/pytorch/pull/117912 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
fe10b1800f
commit
800e2e823f
@@ -11,7 +11,7 @@ import torch
|
||||
|
||||
import torch._inductor
|
||||
|
||||
# The rest of the optimizers not yet imported: LBFGS, RAdam, SparseAdam
|
||||
# LBFGS, SparseAdam not supported
|
||||
from torch.optim import (
|
||||
Adadelta,
|
||||
Adagrad,
|
||||
@@ -20,6 +20,7 @@ from torch.optim import (
|
||||
AdamW,
|
||||
ASGD,
|
||||
NAdam,
|
||||
RAdam,
|
||||
RMSprop,
|
||||
Rprop,
|
||||
SGD,
|
||||
@@ -69,6 +70,9 @@ KERNEL_COUNTS = {
|
||||
Adagrad: KernelCounts(multitensor=5, singletensor=8),
|
||||
ASGD: KernelCounts(multitensor=2, singletensor=12),
|
||||
SGD: KernelCounts(multitensor=2, singletensor=8),
|
||||
RAdam: KernelCounts(
|
||||
multitensor=2, singletensor=None
|
||||
), # Single tensor eager needs to be refactored to enable tracing
|
||||
Adamax: KernelCounts(
|
||||
multitensor=2, singletensor=None
|
||||
), # Single tensor eager needs to be refactored to enable tracing
|
||||
|
||||
@@ -3113,7 +3113,12 @@ exit(2)
|
||||
for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True))
|
||||
] + [
|
||||
(optimizer_ctor, {"lr": 0.1, "foreach": True, "maximize": maximize, "weight_decay": weight_decay})
|
||||
for optimizer_ctor, maximize, weight_decay in product((torch.optim.Adamax, torch.optim.ASGD), (False, True), (0, 0.1))]
|
||||
for optimizer_ctor, maximize, weight_decay in product((torch.optim.Adamax, torch.optim.ASGD), (False, True), (0, 0.1))
|
||||
] + [
|
||||
(torch.optim.RAdam, {"lr": 0.1, "foreach": True, "decoupled_weight_decay": decoupled_weight_decay,
|
||||
"weight_decay": weight_decay})
|
||||
for decoupled_weight_decay, weight_decay in product((False, True), (0.0, 0.1))
|
||||
]
|
||||
|
||||
for optimizer_ctor, kwargs in cases:
|
||||
with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
|
||||
|
||||
@@ -18,6 +18,13 @@ from torch.testing._internal.common_utils import markDynamoStrictTest, parametri
|
||||
|
||||
FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
|
||||
|
||||
|
||||
def _make_radam_single_tensor_non_capturable(optim_cls, kwargs):
|
||||
# Remove this function once https://github.com/pytorch/pytorch/issues/118230 is completed
|
||||
if optim_cls == torch.optim.RAdam and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
|
||||
# Radam does not support capturable single tensor
|
||||
kwargs["capturable"] = False
|
||||
|
||||
@markDynamoStrictTest
|
||||
class TestOptimRenewed(TestCase):
|
||||
|
||||
@@ -71,6 +78,9 @@ class TestOptimRenewed(TestCase):
|
||||
weight = Parameter(torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0])
|
||||
bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
|
||||
input = torch.randn(5, device=device, dtype=dtype)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
|
||||
|
||||
def closure():
|
||||
@@ -94,6 +104,7 @@ class TestOptimRenewed(TestCase):
|
||||
self.assertLess(closure().item(), initial_value)
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
|
||||
@optims(optim_db, dtypes=[torch.float32])
|
||||
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
|
||||
@@ -102,6 +113,10 @@ class TestOptimRenewed(TestCase):
|
||||
for optim_input in optim_inputs:
|
||||
if "foreach" in optim_info.supported_impls:
|
||||
optim_input.kwargs["foreach"] = False # force forloop
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
|
||||
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
|
||||
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
|
||||
input = torch.randn(5, device="cuda:0", dtype=dtype)
|
||||
@@ -138,6 +153,9 @@ class TestOptimRenewed(TestCase):
|
||||
complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
|
||||
real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params]
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
|
||||
complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
|
||||
real_optimizer = optim_cls(real_params, **optim_input.kwargs)
|
||||
|
||||
@@ -175,6 +193,10 @@ class TestOptimRenewed(TestCase):
|
||||
continue
|
||||
for flag_value in (False, True):
|
||||
kwargs[flag] = flag_value
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
|
||||
|
||||
input = torch.tensor(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
|
||||
).reshape(3, 2)
|
||||
@@ -289,6 +311,8 @@ class TestOptimRenewed(TestCase):
|
||||
p_clone.grad = p.grad.clone().detach()
|
||||
params_clone.append(p_clone)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
|
||||
optimizer = optim_cls(params_clone, **kwargs)
|
||||
for _ in range(kIterations):
|
||||
optimizer.step()
|
||||
@@ -361,6 +385,9 @@ class TestOptimRenewed(TestCase):
|
||||
for flag_value in (False, True):
|
||||
kwargs["foreach"] = flag_value
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
|
||||
|
||||
# The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
|
||||
# meaning any tensor that occupies <512 bytes of memory will allocate a whole
|
||||
# 512 bytes anyway. We use 128 (since datasize would be 4 bytes) so that param
|
||||
@@ -384,7 +411,7 @@ class TestOptimRenewed(TestCase):
|
||||
st_max_mem, mt_max_mem = max_mems
|
||||
intermediate_size = nparams * param.nelement() * param.element_size()
|
||||
nintermediates = 1 # we expect a budget of 1 intermediate most of the time
|
||||
if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD"]:
|
||||
if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
|
||||
# with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
|
||||
# with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
|
||||
# ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
|
||||
@@ -394,6 +421,11 @@ class TestOptimRenewed(TestCase):
|
||||
# bias_correction, mus, and mu_nexts
|
||||
nintermediates = 5
|
||||
|
||||
if optim_cls.__name__ == "RAdam":
|
||||
# RAdam has four intermediates with capturable
|
||||
# num, unrect_step_size, buffer, grouped_grads
|
||||
nintermediates = 4
|
||||
|
||||
elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
|
||||
# NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
|
||||
# Adagrad uses std and grads at the same time
|
||||
@@ -552,6 +584,7 @@ class TestOptimRenewed(TestCase):
|
||||
return torch.tensor([1], device=device, dtype=dtype)
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
optimizer = optim_cls(params, **optim_input.kwargs)
|
||||
optimizer.step(closure)
|
||||
self.assertEqual(old_params, params)
|
||||
@@ -569,6 +602,7 @@ class TestOptimRenewed(TestCase):
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
kwargs = optim_input.kwargs
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
|
||||
# params will decay even if grads are empty if weight_decay != 0,
|
||||
# and capturable doesn't work for CPU tensors
|
||||
@@ -730,6 +764,12 @@ class TestOptimRenewed(TestCase):
|
||||
return lbfgs_loss if optim_cls.__name__ == "LBFGS" else None
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
kwargs = optim_input.kwargs
|
||||
# See https://github.com/pytorch/pytorch/issues/117836 for Adamax
|
||||
# See https://github.com/pytorch/pytorch/issues/118230 for RAdam
|
||||
if optim_cls.__name__ in ["Adamax", "RAdam"] and kwargs.get("capturable", False) and not kwargs.get("foreach", False):
|
||||
continue
|
||||
|
||||
optimizer = optim_cls(params, **optim_input.kwargs)
|
||||
for _ in range(3):
|
||||
optimizer.step(closure)
|
||||
|
||||
@@ -1519,41 +1519,42 @@ class TorchPatcher:
|
||||
sparse_adam,
|
||||
}
|
||||
|
||||
disabled_multi_tensor_opt_modules = {
|
||||
radam, # data-dependent control flow
|
||||
excluded_single_tensor = {
|
||||
radam, # https://github.com/pytorch/pytorch/issues/117807
|
||||
}
|
||||
|
||||
for opt_mod in optimizer_modules:
|
||||
opt_name = opt_mod.__name__.split(".")[-1]
|
||||
multi_tensor_fn_name = f"_multi_tensor_{opt_name}"
|
||||
fused_fn_name = f"_fused_{opt_name}"
|
||||
if (
|
||||
hasattr(opt_mod, multi_tensor_fn_name)
|
||||
and opt_mod in disabled_multi_tensor_opt_modules
|
||||
):
|
||||
setattr(
|
||||
opt_mod,
|
||||
multi_tensor_fn_name,
|
||||
disable(getattr(opt_mod, multi_tensor_fn_name)),
|
||||
)
|
||||
single_tensor_fn_name = f"_single_tensor_{opt_name}"
|
||||
|
||||
if hasattr(opt_mod, fused_fn_name):
|
||||
setattr(
|
||||
opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(opt_mod, single_tensor_fn_name)
|
||||
and opt_mod in excluded_single_tensor
|
||||
):
|
||||
setattr(
|
||||
opt_mod,
|
||||
single_tensor_fn_name,
|
||||
disable(getattr(opt_mod, single_tensor_fn_name)),
|
||||
)
|
||||
|
||||
optimizer_classes = [
|
||||
opt
|
||||
for opt in torch.optim.__dict__.values()
|
||||
if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
|
||||
]
|
||||
|
||||
# Note: we don't support sparsity, data-dependent control, or tracing through backwards
|
||||
# Note: we don't support sparsity or tracing through backwards
|
||||
excluded_optimizer_classes = {
|
||||
torch.optim.SparseAdam,
|
||||
torch.optim.RAdam,
|
||||
torch.optim.LBFGS,
|
||||
}
|
||||
|
||||
for opt in optimizer_classes:
|
||||
if opt in excluded_optimizer_classes:
|
||||
opt.step = disable(opt.step)
|
||||
|
||||
@@ -8,11 +8,11 @@ from .optimizer import (
|
||||
Optimizer,
|
||||
_default_to_fused_or_foreach,
|
||||
_differentiable_doc,
|
||||
_capturable_doc,
|
||||
_dispatch_sqrt,
|
||||
_foreach_doc,
|
||||
_get_scalar_dtype,
|
||||
_get_value,
|
||||
_stack_if_compiling,
|
||||
_use_grad_for_differentiable,
|
||||
_view_as_real,
|
||||
)
|
||||
@@ -31,6 +31,7 @@ class RAdam(Optimizer):
|
||||
decoupled_weight_decay: bool = False,
|
||||
*,
|
||||
foreach: Optional[bool] = None,
|
||||
capturable: bool = False,
|
||||
differentiable: bool = False,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
@@ -49,6 +50,7 @@ class RAdam(Optimizer):
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
foreach=foreach,
|
||||
capturable=capturable,
|
||||
decoupled_weight_decay=decoupled_weight_decay,
|
||||
differentiable=differentiable,
|
||||
)
|
||||
@@ -60,13 +62,13 @@ class RAdam(Optimizer):
|
||||
group.setdefault("foreach", None)
|
||||
group.setdefault("differentiable", False)
|
||||
group.setdefault("decoupled_weight_decay", False)
|
||||
state_values = list(self.state.values())
|
||||
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
|
||||
state_values[0]["step"]
|
||||
)
|
||||
if not step_is_tensor:
|
||||
for s in state_values:
|
||||
s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
|
||||
group.setdefault("capturable", False)
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable']
|
||||
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
|
||||
|
||||
def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps):
|
||||
has_complex = False
|
||||
@@ -81,7 +83,11 @@ class RAdam(Optimizer):
|
||||
state = self.state[p]
|
||||
# Lazy state initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||
state['step'] = (
|
||||
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
|
||||
if group['capturable']
|
||||
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||
)
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
@@ -132,6 +138,7 @@ class RAdam(Optimizer):
|
||||
weight_decay=group["weight_decay"],
|
||||
eps=group["eps"],
|
||||
foreach=group["foreach"],
|
||||
capturable=group["capturable"],
|
||||
differentiable=group["differentiable"],
|
||||
decoupled_weight_decay=group["decoupled_weight_decay"],
|
||||
has_complex=has_complex,
|
||||
@@ -201,6 +208,7 @@ RAdam.__doc__ = r"""Implements RAdam algorithm.
|
||||
decay as in AdamW to obtain RAdamW (default: False)
|
||||
{_foreach_doc}
|
||||
{_differentiable_doc}
|
||||
{_capturable_doc}
|
||||
|
||||
.. _On the variance of the adaptive learning rate and beyond:
|
||||
https://arxiv.org/abs/1908.03265
|
||||
@@ -223,6 +231,7 @@ def radam(
|
||||
decoupled_weight_decay: bool = False,
|
||||
foreach: Optional[bool] = None,
|
||||
differentiable: bool = False,
|
||||
capturable: bool = False,
|
||||
has_complex: bool = False,
|
||||
*,
|
||||
beta1: float,
|
||||
@@ -265,6 +274,7 @@ def radam(
|
||||
eps=eps,
|
||||
decoupled_weight_decay=decoupled_weight_decay,
|
||||
differentiable=differentiable,
|
||||
capturable=capturable,
|
||||
has_complex=has_complex,
|
||||
)
|
||||
|
||||
@@ -283,8 +293,11 @@ def _single_tensor_radam(
|
||||
eps: float,
|
||||
differentiable: bool,
|
||||
decoupled_weight_decay: bool,
|
||||
capturable: bool,
|
||||
has_complex: bool,
|
||||
):
|
||||
if capturable:
|
||||
raise RuntimeError("capturable is not supported for single tensor radam")
|
||||
|
||||
for i, param in enumerate(params):
|
||||
grad = grads[i]
|
||||
@@ -356,6 +369,7 @@ def _multi_tensor_radam(
|
||||
eps: float,
|
||||
decoupled_weight_decay: bool,
|
||||
differentiable: bool,
|
||||
capturable: bool,
|
||||
has_complex: bool,
|
||||
):
|
||||
|
||||
@@ -364,6 +378,11 @@ def _multi_tensor_radam(
|
||||
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
|
||||
"If capturable=True, params and state_steps must be CUDA tensors."
|
||||
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
|
||||
for ((
|
||||
grouped_params,
|
||||
@@ -387,8 +406,21 @@ def _multi_tensor_radam(
|
||||
# maximum length of the approximated SMA
|
||||
rho_inf = 2 / (1 - beta2) - 1
|
||||
# compute the length of the approximated SMA
|
||||
rho_t_list = [rho_inf - 2 * _get_value(step) * (beta2 ** _get_value(step)) /
|
||||
(1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
|
||||
if capturable:
|
||||
bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
|
||||
torch._foreach_neg_(bias_correction1)
|
||||
torch._foreach_add_(bias_correction1, 1)
|
||||
bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
|
||||
torch._foreach_mul_(bias_correction2, grouped_state_steps)
|
||||
torch._foreach_mul_(bias_correction2, 2)
|
||||
torch._foreach_div_(bias_correction2, bias_correction1)
|
||||
torch._foreach_neg_(bias_correction2)
|
||||
torch._foreach_add_(bias_correction2, rho_inf)
|
||||
rho_t_list = bias_correction2
|
||||
else:
|
||||
rho_t_list = [rho_inf - 2 * _get_value(step) * (beta2 ** _get_value(step)) /
|
||||
(1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
|
||||
|
||||
|
||||
if weight_decay != 0:
|
||||
if decoupled_weight_decay:
|
||||
@@ -405,29 +437,67 @@ def _multi_tensor_radam(
|
||||
# Delete the local intermediate since it won't be used anymore to save on peak memory
|
||||
del grouped_grads
|
||||
|
||||
rect = [
|
||||
_dispatch_sqrt(
|
||||
(rho_t - 4)
|
||||
* (rho_t - 2)
|
||||
* rho_inf
|
||||
/ ((rho_inf - 4) * (rho_inf - 2) * rho_t)
|
||||
)
|
||||
if rho_t > 5
|
||||
else 0
|
||||
for rho_t in rho_t_list
|
||||
]
|
||||
unrectified = [0 if rect > 0 else 1.0 for rect in rect]
|
||||
if capturable:
|
||||
num = torch._foreach_sub(rho_t_list, 4)
|
||||
sub2 = torch._foreach_sub(rho_t_list, 2)
|
||||
torch._foreach_mul_(num, sub2)
|
||||
del sub2
|
||||
torch._foreach_mul_(num, rho_inf)
|
||||
rho_inf = ((rho_inf - 4) * (rho_inf - 2))
|
||||
denom = torch._foreach_mul(rho_t_list, rho_inf)
|
||||
torch._foreach_div_(num, denom)
|
||||
del denom
|
||||
torch._foreach_sqrt_(num)
|
||||
|
||||
# TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
|
||||
rect = [torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)]
|
||||
del num
|
||||
del rho_t_list
|
||||
unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
|
||||
torch._foreach_mul_(unrect_step_size, lr)
|
||||
|
||||
bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps)
|
||||
torch._foreach_neg_(bias_correction1)
|
||||
torch._foreach_add_(bias_correction1, 1)
|
||||
|
||||
torch._foreach_div_(unrect_step_size, bias_correction1)
|
||||
torch._foreach_neg_(unrect_step_size)
|
||||
|
||||
bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
|
||||
torch._foreach_neg_(bias_correction2)
|
||||
torch._foreach_add_(bias_correction2, 1)
|
||||
torch._foreach_sqrt_(bias_correction2)
|
||||
torch._foreach_mul_(bias_correction2, lr)
|
||||
torch._foreach_mul_(bias_correction2, rect)
|
||||
del rect
|
||||
torch._foreach_neg_(bias_correction2)
|
||||
torch._foreach_div_(bias_correction2, bias_correction1)
|
||||
del bias_correction1
|
||||
else:
|
||||
rect = [
|
||||
_dispatch_sqrt(
|
||||
(rho_t - 4)
|
||||
* (rho_t - 2)
|
||||
* rho_inf
|
||||
/ ((rho_inf - 4) * (rho_inf - 2) * rho_t)
|
||||
)
|
||||
if rho_t > 5
|
||||
else 0
|
||||
for rho_t in rho_t_list
|
||||
]
|
||||
unrectified = [0 if rect > 0 else 1.0 for rect in rect]
|
||||
|
||||
bias_correction1 = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
|
||||
unrect_step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
|
||||
bias_correction2 = [
|
||||
_dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
|
||||
for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
|
||||
]
|
||||
|
||||
bias_correction1 = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
|
||||
unrect_step_size = _stack_if_compiling([(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)])
|
||||
bias_correction2_sqrt_times_rect_step_size = [
|
||||
_dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
|
||||
for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
|
||||
]
|
||||
|
||||
buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)
|
||||
torch._foreach_add_(buffer, eps)
|
||||
torch._foreach_div_(buffer, bias_correction2_sqrt_times_rect_step_size)
|
||||
torch._foreach_div_(buffer, bias_correction2)
|
||||
torch._foreach_reciprocal_(buffer)
|
||||
torch._foreach_add_(buffer, unrect_step_size)
|
||||
|
||||
|
||||
@@ -629,6 +629,26 @@ def optim_error_inputs_func_nadam(device, dtype):
|
||||
|
||||
# Weird story bro, NAdam and RAdam do not have maximize.
|
||||
def optim_inputs_func_radam(device=None):
|
||||
cuda_supported_configs = [
|
||||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs={
|
||||
"capturable": True,
|
||||
"weight_decay": 0.1,
|
||||
},
|
||||
desc="capturable, weight_decay",
|
||||
),
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs={
|
||||
"capturable": True,
|
||||
"weight_decay": 0.1,
|
||||
"decoupled_weight_decay": True,
|
||||
},
|
||||
desc="capturable, weight_decay, decoupled_weight_decay",
|
||||
),
|
||||
]
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
|
||||
@@ -641,7 +661,7 @@ def optim_inputs_func_radam(device=None):
|
||||
kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
|
||||
desc="decoupled_weight_decay",
|
||||
),
|
||||
]
|
||||
] + (cuda_supported_configs if "cuda" in str(device) else [])
|
||||
|
||||
|
||||
def optim_error_inputs_func_radam(device, dtype):
|
||||
@@ -1489,6 +1509,104 @@ optim_db: List[OptimizerInfo] = [
|
||||
"TestOptimRenewed",
|
||||
"test_deepcopy_copies_all_public_attrs",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"See https://github.com/pytorch/pytorch/issues/115607"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_foreach_matches_forloop",
|
||||
),
|
||||
DecorateInfo(
|
||||
toleranceOverride(
|
||||
{
|
||||
# previously atol=1e-7, rtol=1e-7
|
||||
torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
|
||||
}
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_foreach_matches_forloop",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"See https://github.com/pytorch/pytorch/issues/116494"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_state_dict_deterministic",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/115607"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_can_load_older_state_dict",
|
||||
device_type="cpu",
|
||||
),
|
||||
DecorateInfo(
|
||||
toleranceOverride(
|
||||
{ # previously atol=5-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
|
||||
torch.float32: tol(atol=5e-04, rtol=0.01),
|
||||
}
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_mixed_device_dtype",
|
||||
active_if=TEST_WITH_TORCHDYNAMO,
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_complex",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_step_is_noop_when_params_have_no_grad",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_load_nontensor_step",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_param_groups_weight_decay",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_param_groups_lr",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_step_is_noop_for_zero_grads",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_state_dict_with_cuda_params",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_mixed_device_dtype",
|
||||
),
|
||||
),
|
||||
),
|
||||
OptimizerInfo(
|
||||
|
||||
Reference in New Issue
Block a user