Add compilable foreach RAdam support (#117912)

Fixes https://github.com/pytorch/pytorch/issues/117807

This brings the number of supported optimizers with `torch.compile` to 11/13 (!)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117912
Approved by: https://github.com/janeyx99
This commit is contained in:
Michael Lazos
2024-01-27 04:32:27 +00:00
committed by PyTorch MergeBot
parent fe10b1800f
commit 800e2e823f
6 changed files with 286 additions and 48 deletions

View File

@@ -11,7 +11,7 @@ import torch
import torch._inductor
# The rest of the optimizers not yet imported: LBFGS, RAdam, SparseAdam
# LBFGS, SparseAdam not supported
from torch.optim import (
Adadelta,
Adagrad,
@@ -20,6 +20,7 @@ from torch.optim import (
AdamW,
ASGD,
NAdam,
RAdam,
RMSprop,
Rprop,
SGD,
@@ -69,6 +70,9 @@ KERNEL_COUNTS = {
Adagrad: KernelCounts(multitensor=5, singletensor=8),
ASGD: KernelCounts(multitensor=2, singletensor=12),
SGD: KernelCounts(multitensor=2, singletensor=8),
RAdam: KernelCounts(
multitensor=2, singletensor=None
), # Single tensor eager needs to be refactored to enable tracing
Adamax: KernelCounts(
multitensor=2, singletensor=None
), # Single tensor eager needs to be refactored to enable tracing

View File

@@ -3113,7 +3113,12 @@ exit(2)
for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True))
] + [
(optimizer_ctor, {"lr": 0.1, "foreach": True, "maximize": maximize, "weight_decay": weight_decay})
for optimizer_ctor, maximize, weight_decay in product((torch.optim.Adamax, torch.optim.ASGD), (False, True), (0, 0.1))]
for optimizer_ctor, maximize, weight_decay in product((torch.optim.Adamax, torch.optim.ASGD), (False, True), (0, 0.1))
] + [
(torch.optim.RAdam, {"lr": 0.1, "foreach": True, "decoupled_weight_decay": decoupled_weight_decay,
"weight_decay": weight_decay})
for decoupled_weight_decay, weight_decay in product((False, True), (0.0, 0.1))
]
for optimizer_ctor, kwargs in cases:
with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):

View File

@@ -18,6 +18,13 @@ from torch.testing._internal.common_utils import markDynamoStrictTest, parametri
FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
def _make_radam_single_tensor_non_capturable(optim_cls, kwargs):
    """Turn off ``capturable`` in ``kwargs`` for single-tensor (non-foreach) RAdam.

    ``kwargs`` is modified in place and nothing is returned. The only case
    that is changed is ``optim_cls`` equal to ``torch.optim.RAdam`` with a
    missing/falsy ``"foreach"`` entry and a truthy ``"capturable"`` entry;
    every other combination is left untouched.
    """
    # TODO: delete this helper once https://github.com/pytorch/pytorch/issues/118230 is completed
    if optim_cls == torch.optim.RAdam:
        if kwargs.get("foreach", False):
            # The foreach implementation may keep whatever capturable value was requested.
            return
        if kwargs.get("capturable", False):
            # RAdam does not support capturable in its single tensor implementation
            kwargs["capturable"] = False
@markDynamoStrictTest
class TestOptimRenewed(TestCase):
@@ -71,6 +78,9 @@ class TestOptimRenewed(TestCase):
weight = Parameter(torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0])
bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
input = torch.randn(5, device=device, dtype=dtype)
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
def closure():
@@ -94,6 +104,7 @@ class TestOptimRenewed(TestCase):
self.assertLess(closure().item(), initial_value)
@onlyCUDA
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
@@ -102,6 +113,10 @@ class TestOptimRenewed(TestCase):
for optim_input in optim_inputs:
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
input = torch.randn(5, device="cuda:0", dtype=dtype)
@@ -138,6 +153,9 @@ class TestOptimRenewed(TestCase):
complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params]
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
real_optimizer = optim_cls(real_params, **optim_input.kwargs)
@@ -175,6 +193,10 @@ class TestOptimRenewed(TestCase):
continue
for flag_value in (False, True):
kwargs[flag] = flag_value
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
input = torch.tensor(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
).reshape(3, 2)
@@ -289,6 +311,8 @@ class TestOptimRenewed(TestCase):
p_clone.grad = p.grad.clone().detach()
params_clone.append(p_clone)
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
optimizer = optim_cls(params_clone, **kwargs)
for _ in range(kIterations):
optimizer.step()
@@ -361,6 +385,9 @@ class TestOptimRenewed(TestCase):
for flag_value in (False, True):
kwargs["foreach"] = flag_value
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
# The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
# meaning any tensor that occupies <512 bytes of memory will allocate a whole
# 512 bytes anyway. We use 128 (since datasize would be 4 bytes) so that param
@@ -384,7 +411,7 @@ class TestOptimRenewed(TestCase):
st_max_mem, mt_max_mem = max_mems
intermediate_size = nparams * param.nelement() * param.element_size()
nintermediates = 1 # we expect a budget of 1 intermediate most of the time
if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD"]:
if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
# with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
# with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
# ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
@@ -394,6 +421,11 @@ class TestOptimRenewed(TestCase):
# bias_correction, mus, and mu_nexts
nintermediates = 5
if optim_cls.__name__ == "RAdam":
# RAdam has four intermediates with capturable
# num, unrect_step_size, buffer, grouped_grads
nintermediates = 4
elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
# NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
# Adagrad uses std and grads at the same time
@@ -552,6 +584,7 @@ class TestOptimRenewed(TestCase):
return torch.tensor([1], device=device, dtype=dtype)
for optim_input in all_optim_inputs:
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
optimizer = optim_cls(params, **optim_input.kwargs)
optimizer.step(closure)
self.assertEqual(old_params, params)
@@ -569,6 +602,7 @@ class TestOptimRenewed(TestCase):
for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
# params will decay even if grads are empty if weight_decay != 0,
# and capturable doesn't work for CPU tensors
@@ -730,6 +764,12 @@ class TestOptimRenewed(TestCase):
return lbfgs_loss if optim_cls.__name__ == "LBFGS" else None
for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
# See https://github.com/pytorch/pytorch/issues/117836 for Adamax
# See https://github.com/pytorch/pytorch/issues/118230 for RAdam
if optim_cls.__name__ in ["Adamax", "RAdam"] and kwargs.get("capturable", False) and not kwargs.get("foreach", False):
continue
optimizer = optim_cls(params, **optim_input.kwargs)
for _ in range(3):
optimizer.step(closure)

View File

@@ -1519,41 +1519,42 @@ class TorchPatcher:
sparse_adam,
}
disabled_multi_tensor_opt_modules = {
radam, # data-dependent control flow
excluded_single_tensor = {
radam, # https://github.com/pytorch/pytorch/issues/117807
}
for opt_mod in optimizer_modules:
opt_name = opt_mod.__name__.split(".")[-1]
multi_tensor_fn_name = f"_multi_tensor_{opt_name}"
fused_fn_name = f"_fused_{opt_name}"
if (
hasattr(opt_mod, multi_tensor_fn_name)
and opt_mod in disabled_multi_tensor_opt_modules
):
setattr(
opt_mod,
multi_tensor_fn_name,
disable(getattr(opt_mod, multi_tensor_fn_name)),
)
single_tensor_fn_name = f"_single_tensor_{opt_name}"
if hasattr(opt_mod, fused_fn_name):
setattr(
opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
)
if (
hasattr(opt_mod, single_tensor_fn_name)
and opt_mod in excluded_single_tensor
):
setattr(
opt_mod,
single_tensor_fn_name,
disable(getattr(opt_mod, single_tensor_fn_name)),
)
optimizer_classes = [
opt
for opt in torch.optim.__dict__.values()
if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
]
# Note: we don't support sparsity, data-dependent control, or tracing through backwards
# Note: we don't support sparsity or tracing through backwards
excluded_optimizer_classes = {
torch.optim.SparseAdam,
torch.optim.RAdam,
torch.optim.LBFGS,
}
for opt in optimizer_classes:
if opt in excluded_optimizer_classes:
opt.step = disable(opt.step)

View File

@@ -8,11 +8,11 @@ from .optimizer import (
Optimizer,
_default_to_fused_or_foreach,
_differentiable_doc,
_capturable_doc,
_dispatch_sqrt,
_foreach_doc,
_get_scalar_dtype,
_get_value,
_stack_if_compiling,
_use_grad_for_differentiable,
_view_as_real,
)
@@ -31,6 +31,7 @@ class RAdam(Optimizer):
decoupled_weight_decay: bool = False,
*,
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
):
if not 0.0 <= lr:
@@ -49,6 +50,7 @@ class RAdam(Optimizer):
eps=eps,
weight_decay=weight_decay,
foreach=foreach,
capturable=capturable,
decoupled_weight_decay=decoupled_weight_decay,
differentiable=differentiable,
)
@@ -60,13 +62,13 @@ class RAdam(Optimizer):
group.setdefault("foreach", None)
group.setdefault("differentiable", False)
group.setdefault("decoupled_weight_decay", False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
)
if not step_is_tensor:
for s in state_values:
s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
group.setdefault("capturable", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
step_val = float(p_state["step"])
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable']
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps):
has_complex = False
@@ -81,7 +83,11 @@ class RAdam(Optimizer):
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
state['step'] = (
torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
if group['capturable']
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
@@ -132,6 +138,7 @@ class RAdam(Optimizer):
weight_decay=group["weight_decay"],
eps=group["eps"],
foreach=group["foreach"],
capturable=group["capturable"],
differentiable=group["differentiable"],
decoupled_weight_decay=group["decoupled_weight_decay"],
has_complex=has_complex,
@@ -201,6 +208,7 @@ RAdam.__doc__ = r"""Implements RAdam algorithm.
decay as in AdamW to obtain RAdamW (default: False)
{_foreach_doc}
{_differentiable_doc}
{_capturable_doc}
.. _On the variance of the adaptive learning rate and beyond:
https://arxiv.org/abs/1908.03265
@@ -223,6 +231,7 @@ def radam(
decoupled_weight_decay: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
beta1: float,
@@ -265,6 +274,7 @@ def radam(
eps=eps,
decoupled_weight_decay=decoupled_weight_decay,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)
@@ -283,8 +293,11 @@ def _single_tensor_radam(
eps: float,
differentiable: bool,
decoupled_weight_decay: bool,
capturable: bool,
has_complex: bool,
):
if capturable:
raise RuntimeError("capturable is not supported for single tensor radam")
for i, param in enumerate(params):
grad = grads[i]
@@ -356,6 +369,7 @@ def _multi_tensor_radam(
eps: float,
decoupled_weight_decay: bool,
differentiable: bool,
capturable: bool,
has_complex: bool,
):
@@ -364,6 +378,11 @@ def _multi_tensor_radam(
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
"If capturable=True, params and state_steps must be CUDA tensors."
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
for ((
grouped_params,
@@ -387,8 +406,21 @@ def _multi_tensor_radam(
# maximum length of the approximated SMA
rho_inf = 2 / (1 - beta2) - 1
# compute the length of the approximated SMA
rho_t_list = [rho_inf - 2 * _get_value(step) * (beta2 ** _get_value(step)) /
(1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
if capturable:
bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
torch._foreach_neg_(bias_correction1)
torch._foreach_add_(bias_correction1, 1)
bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
torch._foreach_mul_(bias_correction2, grouped_state_steps)
torch._foreach_mul_(bias_correction2, 2)
torch._foreach_div_(bias_correction2, bias_correction1)
torch._foreach_neg_(bias_correction2)
torch._foreach_add_(bias_correction2, rho_inf)
rho_t_list = bias_correction2
else:
rho_t_list = [rho_inf - 2 * _get_value(step) * (beta2 ** _get_value(step)) /
(1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
if weight_decay != 0:
if decoupled_weight_decay:
@@ -405,29 +437,67 @@ def _multi_tensor_radam(
# Delete the local intermediate since it won't be used anymore to save on peak memory
del grouped_grads
rect = [
_dispatch_sqrt(
(rho_t - 4)
* (rho_t - 2)
* rho_inf
/ ((rho_inf - 4) * (rho_inf - 2) * rho_t)
)
if rho_t > 5
else 0
for rho_t in rho_t_list
]
unrectified = [0 if rect > 0 else 1.0 for rect in rect]
if capturable:
num = torch._foreach_sub(rho_t_list, 4)
sub2 = torch._foreach_sub(rho_t_list, 2)
torch._foreach_mul_(num, sub2)
del sub2
torch._foreach_mul_(num, rho_inf)
rho_inf = ((rho_inf - 4) * (rho_inf - 2))
denom = torch._foreach_mul(rho_t_list, rho_inf)
torch._foreach_div_(num, denom)
del denom
torch._foreach_sqrt_(num)
# TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
rect = [torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)]
del num
del rho_t_list
unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
torch._foreach_mul_(unrect_step_size, lr)
bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps)
torch._foreach_neg_(bias_correction1)
torch._foreach_add_(bias_correction1, 1)
torch._foreach_div_(unrect_step_size, bias_correction1)
torch._foreach_neg_(unrect_step_size)
bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
torch._foreach_neg_(bias_correction2)
torch._foreach_add_(bias_correction2, 1)
torch._foreach_sqrt_(bias_correction2)
torch._foreach_mul_(bias_correction2, lr)
torch._foreach_mul_(bias_correction2, rect)
del rect
torch._foreach_neg_(bias_correction2)
torch._foreach_div_(bias_correction2, bias_correction1)
del bias_correction1
else:
rect = [
_dispatch_sqrt(
(rho_t - 4)
* (rho_t - 2)
* rho_inf
/ ((rho_inf - 4) * (rho_inf - 2) * rho_t)
)
if rho_t > 5
else 0
for rho_t in rho_t_list
]
unrectified = [0 if rect > 0 else 1.0 for rect in rect]
bias_correction1 = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
unrect_step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
bias_correction2 = [
_dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
]
bias_correction1 = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
unrect_step_size = _stack_if_compiling([(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)])
bias_correction2_sqrt_times_rect_step_size = [
_dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
]
buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)
torch._foreach_add_(buffer, eps)
torch._foreach_div_(buffer, bias_correction2_sqrt_times_rect_step_size)
torch._foreach_div_(buffer, bias_correction2)
torch._foreach_reciprocal_(buffer)
torch._foreach_add_(buffer, unrect_step_size)

View File

@@ -629,6 +629,26 @@ def optim_error_inputs_func_nadam(device, dtype):
# NOTE: NAdam and RAdam do not have a `maximize` option.
def optim_inputs_func_radam(device=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={
"capturable": True,
"weight_decay": 0.1,
},
desc="capturable, weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"capturable": True,
"weight_decay": 0.1,
"decoupled_weight_decay": True,
},
desc="capturable, weight_decay, decoupled_weight_decay",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
@@ -641,7 +661,7 @@ def optim_inputs_func_radam(device=None):
kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
desc="decoupled_weight_decay",
),
]
] + (cuda_supported_configs if "cuda" in str(device) else [])
def optim_error_inputs_func_radam(device, dtype):
@@ -1489,6 +1509,104 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_deepcopy_copies_all_public_attrs",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115607"
),
"TestOptimRenewed",
"test_foreach_matches_forloop",
),
DecorateInfo(
toleranceOverride(
{
# previously atol=1e-7, rtol=1e-7
torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
}
),
"TestOptimRenewed",
"test_foreach_matches_forloop",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/115607"
),
"TestOptimRenewed",
"test_can_load_older_state_dict",
device_type="cpu",
),
DecorateInfo(
toleranceOverride(
{ # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
torch.float32: tol(atol=5e-04, rtol=0.01),
}
),
"TestOptimRenewed",
"test_mixed_device_dtype",
active_if=TEST_WITH_TORCHDYNAMO,
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_complex",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_step_is_noop_when_params_have_no_grad",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_load_nontensor_step",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_param_groups_weight_decay",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_param_groups_lr",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_state_dict_with_cuda_params",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_mixed_device_dtype",
),
),
),
OptimizerInfo(