diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py
index 8e0eb4ff8b6..636c981b566 100644
--- a/test/inductor/test_compiled_optimizers.py
+++ b/test/inductor/test_compiled_optimizers.py
@@ -11,7 +11,7 @@ import torch
 
 import torch._inductor
 
-# The rest of the optimizers not yet imported: LBFGS, RAdam, SparseAdam
+# LBFGS, SparseAdam not supported
 from torch.optim import (
     Adadelta,
     Adagrad,
@@ -20,6 +20,7 @@ from torch.optim import (
     AdamW,
     ASGD,
     NAdam,
+    RAdam,
     RMSprop,
     Rprop,
     SGD,
@@ -69,6 +70,9 @@ KERNEL_COUNTS = {
     Adagrad: KernelCounts(multitensor=5, singletensor=8),
     ASGD: KernelCounts(multitensor=2, singletensor=12),
     SGD: KernelCounts(multitensor=2, singletensor=8),
+    RAdam: KernelCounts(
+        multitensor=2, singletensor=None
+    ),  # Single tensor eager needs to be refactored to enable tracing
     Adamax: KernelCounts(
         multitensor=2, singletensor=None
     ),  # Single tensor eager needs to be refactored to enable tracing
diff --git a/test/test_cuda.py b/test/test_cuda.py
index f855bcbb9b2..3f9143fca5a 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -3113,7 +3113,12 @@ exit(2)
             for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True))
         ] + [
             (optimizer_ctor, {"lr": 0.1, "foreach": True, "maximize": maximize, "weight_decay": weight_decay})
-            for optimizer_ctor, maximize, weight_decay in product((torch.optim.Adamax, torch.optim.ASGD), (False, True), (0, 0.1))]
+            for optimizer_ctor, maximize, weight_decay in product((torch.optim.Adamax, torch.optim.ASGD), (False, True), (0, 0.1))
+        ] + [
+            (torch.optim.RAdam, {"lr": 0.1, "foreach": True, "decoupled_weight_decay": decoupled_weight_decay,
+                                 "weight_decay": weight_decay})
+            for decoupled_weight_decay, weight_decay in product((False, True), (0.0, 0.1))
+        ]
 
         for optimizer_ctor, kwargs in cases:
            with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
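The new test_cuda.py cases above feed foreach RAdam configurations into CUDA graph testing. As a rough illustration of the pattern those cases exercise (not code from this PR; the toy parameter, gradient, and hyperparameters are made up), a capturable RAdam step can be recorded and replayed like this:

```python
import torch

# Hypothetical single parameter with a static gradient buffer.
param = torch.randn(64, device="cuda", requires_grad=True)
param.grad = torch.randn_like(param)

# capturable=True keeps the step counters on the GPU so the update can be
# recorded into a CUDA graph without host synchronization.
opt = torch.optim.RAdam([param], lr=0.1, foreach=True, capturable=True)

# Warm up on a side stream so lazy state (exp_avg, exp_avg_sq, step) exists
# before capture, as recommended for CUDA graph workflows.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    opt.step()

# Each replay re-applies the captured RAdam update using whatever values
# are currently in param.grad.
g.replay()
```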
diff --git a/test/test_optim.py b/test/test_optim.py
index 73c08984d7e..2fa92985c73 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -18,6 +18,13 @@ from torch.testing._internal.common_utils import markDynamoStrictTest, parametri
 
 FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
 
+
+def _make_radam_single_tensor_non_capturable(optim_cls, kwargs):
+    # Remove this function once https://github.com/pytorch/pytorch/issues/118230 is completed
+    if optim_cls == torch.optim.RAdam and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
+        # Radam does not support capturable single tensor
+        kwargs["capturable"] = False
+
 @markDynamoStrictTest
 class TestOptimRenewed(TestCase):
 
@@ -71,6 +78,9 @@ class TestOptimRenewed(TestCase):
             weight = Parameter(torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0])
             bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
             input = torch.randn(5, device=device, dtype=dtype)
+
+            # https://github.com/pytorch/pytorch/issues/118230
+            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
             optimizer = optim_cls([weight, bias], **optim_input.kwargs)
 
             def closure():
@@ -94,6 +104,7 @@ class TestOptimRenewed(TestCase):
 
             self.assertLess(closure().item(), initial_value)
 
+    @onlyCUDA
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     @optims(optim_db, dtypes=[torch.float32])
     def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
@@ -102,6 +113,10 @@ class TestOptimRenewed(TestCase):
         for optim_input in optim_inputs:
             if "foreach" in optim_info.supported_impls:
                 optim_input.kwargs["foreach"] = False  # force forloop
+
+            # https://github.com/pytorch/pytorch/issues/118230
+            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+
             weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
             bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
             input = torch.randn(5, device="cuda:0", dtype=dtype)
@@ -138,6 +153,9 @@ class TestOptimRenewed(TestCase):
             complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
             real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params]
 
+            # https://github.com/pytorch/pytorch/issues/118230
+            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+
             complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
             real_optimizer = optim_cls(real_params, **optim_input.kwargs)
 
@@ -175,6 +193,10 @@ class TestOptimRenewed(TestCase):
                 continue
             for flag_value in (False, True):
                 kwargs[flag] = flag_value
+
+                # https://github.com/pytorch/pytorch/issues/118230
+                _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
+
                 input = torch.tensor(
                     [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
                 ).reshape(3, 2)
@@ -289,6 +311,8 @@ class TestOptimRenewed(TestCase):
                     p_clone.grad = p.grad.clone().detach()
                     params_clone.append(p_clone)
 
+            # https://github.com/pytorch/pytorch/issues/118230
+            _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
             optimizer = optim_cls(params_clone, **kwargs)
             for _ in range(kIterations):
                 optimizer.step()
@@ -361,6 +385,9 @@ class TestOptimRenewed(TestCase):
             for flag_value in (False, True):
                 kwargs["foreach"] = flag_value
 
+                # https://github.com/pytorch/pytorch/issues/118230
+                _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
+
                 # The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
                 # meaning any tensor that occupies <512 bytes of memory will allocate a whole
                 # 512 bytes anyway. We use 128 (since datasize would be 4 bytes) so that param
@@ -384,7 +411,7 @@ class TestOptimRenewed(TestCase):
                 st_max_mem, mt_max_mem = max_mems
                 intermediate_size = nparams * param.nelement() * param.element_size()
                 nintermediates = 1  # we expect a budget of 1 intermediate most of the time
-                if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD"]:
+                if kwargs.get('capturable') or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
                     # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
                     # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
                     # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
@@ -394,6 +421,11 @@ class TestOptimRenewed(TestCase):
                         # bias_correction, mus, and mu_nexts
                         nintermediates = 5
 
+                    if optim_cls.__name__ == "RAdam":
+                        # RAdam has four intermediates with capturable
+                        # num, unrect_step_size, buffer, grouped_grads
+                        nintermediates = 4
+
                 elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
                     # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
                     # Adagrad uses std and grads at the same time
@@ -552,6 +584,7 @@ class TestOptimRenewed(TestCase):
             return torch.tensor([1], device=device, dtype=dtype)
 
         for optim_input in all_optim_inputs:
+            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
             optimizer = optim_cls(params, **optim_input.kwargs)
             optimizer.step(closure)
             self.assertEqual(old_params, params)
@@ -569,6 +602,7 @@ class TestOptimRenewed(TestCase):
 
         for optim_input in all_optim_inputs:
             kwargs = optim_input.kwargs
+            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
 
             # params will decay even if grads are empty if weight_decay != 0,
             # and capturable doesn't work for CPU tensors
@@ -730,6 +764,12 @@ class TestOptimRenewed(TestCase):
             return lbfgs_loss if optim_cls.__name__ == "LBFGS" else None
 
         for optim_input in all_optim_inputs:
+            kwargs = optim_input.kwargs
+            # See https://github.com/pytorch/pytorch/issues/117836 for Adamax
+            # See https://github.com/pytorch/pytorch/issues/118230 for RAdam
+            if optim_cls.__name__ in ["Adamax", "RAdam"] and kwargs.get("capturable", False) and not kwargs.get("foreach", False):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
             for _ in range(3):
                 optimizer.step(closure)
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index c2b0f0a5b55..f54d62567a5 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -1519,41 +1519,42 @@ class TorchPatcher:
             sparse_adam,
         }
 
-        disabled_multi_tensor_opt_modules = {
-            radam,  # data-dependent control flow
+        excluded_single_tensor = {
+            radam,  # https://github.com/pytorch/pytorch/issues/117807
         }
 
         for opt_mod in optimizer_modules:
             opt_name = opt_mod.__name__.split(".")[-1]
-            multi_tensor_fn_name = f"_multi_tensor_{opt_name}"
             fused_fn_name = f"_fused_{opt_name}"
-            if (
-                hasattr(opt_mod, multi_tensor_fn_name)
-                and opt_mod in disabled_multi_tensor_opt_modules
-            ):
-                setattr(
-                    opt_mod,
-                    multi_tensor_fn_name,
-                    disable(getattr(opt_mod, multi_tensor_fn_name)),
-                )
+            single_tensor_fn_name = f"_single_tensor_{opt_name}"
 
             if hasattr(opt_mod, fused_fn_name):
                 setattr(
                     opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name))
                 )
 
+            if (
+                hasattr(opt_mod, single_tensor_fn_name)
+                and opt_mod in excluded_single_tensor
+            ):
+                setattr(
+                    opt_mod,
+                    single_tensor_fn_name,
+                    disable(getattr(opt_mod, single_tensor_fn_name)),
+                )
+
         optimizer_classes = [
             opt
             for opt in torch.optim.__dict__.values()
             if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
         ]
 
-        # Note: we don't support sparsity, data-dependent control, or tracing through backwards
+        # Note: we don't support sparsity or tracing through backwards
         excluded_optimizer_classes = {
             torch.optim.SparseAdam,
-            torch.optim.RAdam,
             torch.optim.LBFGS,
         }
+
         for opt in optimizer_classes:
             if opt in excluded_optimizer_classes:
                 opt.step = disable(opt.step)
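With RAdam dropped from the excluded classes and only its single-tensor implementation wrapped in disable(), the foreach path becomes traceable by dynamo. A minimal sketch of what that unlocks (illustrative only; the toy parameters and learning rate are not from this PR):

```python
import torch

# Hypothetical parameters with pre-populated gradients.
params = [torch.randn(4, 4, requires_grad=True) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p)

opt = torch.optim.RAdam(params, lr=1e-3, foreach=True)

@torch.compile
def compiled_step():
    opt.step()

# Previously RAdam.step was wrapped in disable(), so this always fell back to
# eager; now the multi-tensor RAdam update can be traced and compiled.
compiled_step()
```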
diff --git a/torch/optim/radam.py b/torch/optim/radam.py
index 45c8e7a0095..4184450a2a6 100644
--- a/torch/optim/radam.py
+++ b/torch/optim/radam.py
@@ -8,11 +8,11 @@ from .optimizer import (
     Optimizer,
     _default_to_fused_or_foreach,
     _differentiable_doc,
+    _capturable_doc,
     _dispatch_sqrt,
     _foreach_doc,
     _get_scalar_dtype,
     _get_value,
-    _stack_if_compiling,
     _use_grad_for_differentiable,
     _view_as_real,
 )
@@ -31,6 +31,7 @@ class RAdam(Optimizer):
         decoupled_weight_decay: bool = False,
         *,
         foreach: Optional[bool] = None,
+        capturable: bool = False,
         differentiable: bool = False,
     ):
         if not 0.0 <= lr:
@@ -49,6 +50,7 @@ class RAdam(Optimizer):
             eps=eps,
             weight_decay=weight_decay,
             foreach=foreach,
+            capturable=capturable,
             decoupled_weight_decay=decoupled_weight_decay,
             differentiable=differentiable,
         )
@@ -60,13 +62,13 @@ class RAdam(Optimizer):
             group.setdefault("foreach", None)
             group.setdefault("differentiable", False)
             group.setdefault("decoupled_weight_decay", False)
-        state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["step"]
-        )
-        if not step_is_tensor:
-            for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
+            group.setdefault("capturable", False)
+            for p in group["params"]:
+                p_state = self.state.get(p, [])
+                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
+                    step_val = float(p_state["step"])
+                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable']
+                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps):
         has_complex = False
@@ -81,7 +83,11 @@ class RAdam(Optimizer):
             state = self.state[p]
             # Lazy state initialization
             if len(state) == 0:
-                state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
+                state['step'] = (
+                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
+                    if group['capturable']
+                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
+                )
                 # Exponential moving average of gradient values
                 state["exp_avg"] = torch.zeros_like(
                     p, memory_format=torch.preserve_format
@@ -132,6 +138,7 @@ class RAdam(Optimizer):
                 weight_decay=group["weight_decay"],
                 eps=group["eps"],
                 foreach=group["foreach"],
+                capturable=group["capturable"],
                 differentiable=group["differentiable"],
                 decoupled_weight_decay=group["decoupled_weight_decay"],
                 has_complex=has_complex,
@@ -201,6 +208,7 @@ RAdam.__doc__ = r"""Implements RAdam algorithm.
             decay as in AdamW to obtain RAdamW (default: False)
         {_foreach_doc}
         {_differentiable_doc}
+        {_capturable_doc}
 
     .. _On the variance of the adaptive learning rate and beyond:
         https://arxiv.org/abs/1908.03265
@@ -223,6 +231,7 @@ def radam(
     decoupled_weight_decay: bool = False,
     foreach: Optional[bool] = None,
     differentiable: bool = False,
+    capturable: bool = False,
    has_complex: bool = False,
     *,
     beta1: float,
@@ -265,6 +274,7 @@ def radam(
         eps=eps,
         decoupled_weight_decay=decoupled_weight_decay,
         differentiable=differentiable,
+        capturable=capturable,
         has_complex=has_complex,
     )
 
@@ -283,8 +293,11 @@ def _single_tensor_radam(
     eps: float,
     differentiable: bool,
     decoupled_weight_decay: bool,
+    capturable: bool,
     has_complex: bool,
 ):
+    if capturable:
+        raise RuntimeError("capturable is not supported for single tensor radam")
 
     for i, param in enumerate(params):
         grad = grads[i]
@@ -356,6 +369,7 @@ def _multi_tensor_radam(
     eps: float,
     decoupled_weight_decay: bool,
     differentiable: bool,
+    capturable: bool,
     has_complex: bool,
 ):
 
@@ -364,6 +378,11 @@ def _multi_tensor_radam(
 
     assert not differentiable, "_foreach ops don't support autograd"
 
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
+        assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
+            "If capturable=True, params and state_steps must be CUDA tensors."
+
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
     for ((
         grouped_params,
@@ -387,8 +406,21 @@ def _multi_tensor_radam(
         # maximum length of the approximated SMA
         rho_inf = 2 / (1 - beta2) - 1
         # compute the length of the approximated SMA
-        rho_t_list = [rho_inf - 2 * _get_value(step) * (beta2 ** _get_value(step)) /
-                      (1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
+        if capturable:
+            bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
+            torch._foreach_neg_(bias_correction1)
+            torch._foreach_add_(bias_correction1, 1)
+            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
+            torch._foreach_mul_(bias_correction2, grouped_state_steps)
+            torch._foreach_mul_(bias_correction2, 2)
+            torch._foreach_div_(bias_correction2, bias_correction1)
+            torch._foreach_neg_(bias_correction2)
+            torch._foreach_add_(bias_correction2, rho_inf)
+            rho_t_list = bias_correction2
+        else:
+            rho_t_list = [rho_inf - 2 * _get_value(step) * (beta2 ** _get_value(step)) /
+                          (1 - beta2 ** _get_value(step)) for step in grouped_state_steps]
+
         if weight_decay != 0:
             if decoupled_weight_decay:
@@ -405,29 +437,67 @@ def _multi_tensor_radam(
             # Delete the local intermediate since it won't be used anymore to save on peak memory
             del grouped_grads
 
-        rect = [
-            _dispatch_sqrt(
-                (rho_t - 4)
-                * (rho_t - 2)
-                * rho_inf
-                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
-            )
-            if rho_t > 5
-            else 0
-            for rho_t in rho_t_list
-        ]
-        unrectified = [0 if rect > 0 else 1.0 for rect in rect]
+        if capturable:
+            num = torch._foreach_sub(rho_t_list, 4)
+            sub2 = torch._foreach_sub(rho_t_list, 2)
+            torch._foreach_mul_(num, sub2)
+            del sub2
+            torch._foreach_mul_(num, rho_inf)
+            rho_inf = ((rho_inf - 4) * (rho_inf - 2))
+            denom = torch._foreach_mul(rho_t_list, rho_inf)
+            torch._foreach_div_(num, denom)
+            del denom
+            torch._foreach_sqrt_(num)
+
+            # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
+            rect = [torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)]
+            del num
+            del rho_t_list
+            unrect_step_size = [torch.where(rect > 0, 0.0, 1.0) for rect in rect]
+            torch._foreach_mul_(unrect_step_size, lr)
+
+            bias_correction1 = torch._foreach_pow(beta1, grouped_state_steps)
+            torch._foreach_neg_(bias_correction1)
+            torch._foreach_add_(bias_correction1, 1)
+
+            torch._foreach_div_(unrect_step_size, bias_correction1)
+            torch._foreach_neg_(unrect_step_size)
+
+            bias_correction2 = torch._foreach_pow(beta2, grouped_state_steps)
+            torch._foreach_neg_(bias_correction2)
+            torch._foreach_add_(bias_correction2, 1)
+            torch._foreach_sqrt_(bias_correction2)
+            torch._foreach_mul_(bias_correction2, lr)
+            torch._foreach_mul_(bias_correction2, rect)
+            del rect
+            torch._foreach_neg_(bias_correction2)
+            torch._foreach_div_(bias_correction2, bias_correction1)
+            del bias_correction1
+        else:
+            rect = [
+                _dispatch_sqrt(
+                    (rho_t - 4)
+                    * (rho_t - 2)
+                    * rho_inf
+                    / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
+                )
+                if rho_t > 5
+                else 0
+                for rho_t in rho_t_list
+            ]
+            unrectified = [0 if rect > 0 else 1.0 for rect in rect]
+
+            bias_correction1 = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
+            unrect_step_size = [(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)]
+            bias_correction2 = [
+                _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
+                for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
+            ]
 
-        bias_correction1 = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
-        unrect_step_size = _stack_if_compiling([(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)])
-        bias_correction2_sqrt_times_rect_step_size = [
-            _dispatch_sqrt(1 - beta2 ** _get_value(step)) * (lr * rect / bc) * -1
-            for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
-        ]
 
         buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)
         torch._foreach_add_(buffer, eps)
-        torch._foreach_div_(buffer, bias_correction2_sqrt_times_rect_step_size)
+        torch._foreach_div_(buffer, bias_correction2)
         torch._foreach_reciprocal_(buffer)
         torch._foreach_add_(buffer, unrect_step_size)
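For reference, the capturable branch above is a _foreach formulation of the standard RAdam rectification math. A plain single-parameter sketch of the same quantities (illustrative naming, not the torch.optim implementation; the foreach code folds signs in differently so the final update can be applied with a single fused op):

```python
import math

def radam_step_sizes(step, lr, beta1, beta2):
    # Return (rectified_step_size, unrectified_step_size) for one RAdam step.
    rho_inf = 2 / (1 - beta2) - 1
    # length of the approximated SMA at this step
    rho_t = rho_inf - 2 * step * beta2**step / (1 - beta2**step)
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if rho_t > 5:
        # variance is tractable: take a rectified, Adam-like step
        rect = math.sqrt(
            (rho_t - 4) * (rho_t - 2) * rho_inf
            / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
        )
        return lr * rect * math.sqrt(bias_correction2) / bias_correction1, 0.0
    # variance not yet tractable: fall back to an un-rectified momentum step
    return 0.0, lr / bias_correction1
```

The parameter update then subtracts `exp_avg * (rect_size / (sqrt(exp_avg_sq) + eps) + unrect_size)`, which corresponds to the per-element multiplier built into `buffer` from `bias_correction2` and `unrect_step_size` above (negated there so the update can be applied additively).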
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index 098c59ddc01..ff376cc28a8 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -629,6 +629,26 @@ def optim_error_inputs_func_nadam(device, dtype):
 
 # Weird story bro, NAdam and RAdam do not have maximize.
 def optim_inputs_func_radam(device=None):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+            },
+            desc="capturable, weight_decay",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "capturable": True,
+                "weight_decay": 0.1,
+                "decoupled_weight_decay": True,
+            },
+            desc="capturable, weight_decay, decoupled_weight_decay",
+        ),
+    ]
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
@@ -641,7 +661,7 @@ def optim_inputs_func_radam(device=None):
             kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
             desc="decoupled_weight_decay",
         ),
-    ]
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_radam(device, dtype):
@@ -1489,6 +1509,104 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_deepcopy_copies_all_public_attrs",
             ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "See https://github.com/pytorch/pytorch/issues/115607"
+                ),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {
+                        # previously atol=1e-7, rtol=1e-7
+                        torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_foreach_matches_forloop",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "See https://github.com/pytorch/pytorch/issues/116494"
+                ),
+                "TestOptimRenewed",
+                "test_state_dict_deterministic",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/115607"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="cpu",
+            ),
+            DecorateInfo(
+                toleranceOverride(
+                    {  # previously atol=5-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
+                        torch.float32: tol(atol=5e-04, rtol=0.01),
+                    }
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+                active_if=TEST_WITH_TORCHDYNAMO,
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_complex",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_when_params_have_no_grad",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_state_dict_with_cuda_params",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+            ),
         ),
     ),
     OptimizerInfo(