Skip to content

Optimizers

create_optimizer(model, optimizer_name, lr=0.001, weight_decay=0.0, wd_ban_list=('bias', 'LayerNorm.bias', 'LayerNorm.weight'), use_lookahead=False, use_orthograd=False, **kwargs)

Build optimizer.

Parameters:

Name Type Description Default
model Module

model.

required
optimizer_name str

optimizer name.

required
lr float

learning rate.

0.001
weight_decay float

weight decay.

0.0
wd_ban_list List[str]

weight decay ban list by layer.

('bias', 'LayerNorm.bias', 'LayerNorm.weight')
use_lookahead bool

use Lookahead.

False
use_orthograd bool

use OrthoGrad.

False
**kwargs dict

optimizer parameters.

{}
Source code in pytorch_optimizer/optimizer/__init__.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
def create_optimizer(
    model: nn.Module,
    optimizer_name: str,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
    use_lookahead: bool = False,
    use_orthograd: bool = False,
    **kwargs,
) -> Optimizer:
    r"""Build optimizer.

    The optimizer class is looked up by its (case-insensitive) name. For optimizers that take a parameter
    iterable, parameters are split into a decayed and a non-decayed group via ``get_optimizer_parameters`` when
    ``weight_decay > 0``; otherwise ``model.parameters()`` is passed as-is. ``lomo``, ``adalomo`` and ``adammini``
    are constructed from ``model`` directly, and ``muon``, ``adamuon`` and ``adago`` are built through
    ``prepare_muon_parameters``.

    Args:
        model (nn.Module): model.
        optimizer_name (str): optimizer name.
        lr (float): learning rate. Passed as ``max_lr`` for ``alig``.
        weight_decay (float): weight decay.
        wd_ban_list (List[str]): weight decay ban list by layer.
        use_lookahead (bool): use Lookahead. For ``ranger``, ``ranger21`` and ``ranger25`` a warning is emitted
            and no Lookahead wrapper is added.
        use_orthograd (bool): use OrthoGrad.
        **kwargs (dict): optimizer parameters. ``k``, ``alpha`` and ``pullback_momentum`` are also read from here
            to configure Lookahead when ``use_lookahead`` is set.

    Returns:
        Optimizer: the built optimizer, optionally wrapped by OrthoGrad and/or Lookahead.

    """
    optimizer_name = optimizer_name.lower()

    optimizer_class: OptimizerType = load_optimizer(optimizer_name)

    if optimizer_name in ('lomo', 'adalomo', 'adammini'):
        # these optimizers are constructed from the model itself rather than from a parameter iterable.
        optimizer = optimizer_class(model, lr=lr, **kwargs)
    elif optimizer_name in ('muon', 'adamuon', 'adago'):
        # stacklevel=2 attributes the warning to the caller of `create_optimizer`, not to this line.
        warn(f'highly recommend you to create the {optimizer_name} manually.', UserWarning, stacklevel=2)

        optimizer = prepare_muon_parameters(model, optimizer_name, lr=lr, weight_decay=weight_decay, **kwargs)
    else:
        # only these branches consume a parameter iterable, so the grouping is built here.
        parameters = (
            get_optimizer_parameters(model, weight_decay, wd_ban_list) if weight_decay > 0.0 else model.parameters()
        )

        if optimizer_name == 'alig':
            # AliG receives the learning rate as `max_lr`.
            optimizer = optimizer_class(parameters, max_lr=lr, **kwargs)
        else:
            optimizer = optimizer_class(parameters, lr=lr, **kwargs)

    if use_orthograd:
        optimizer = OrthoGrad(optimizer, **kwargs)

    if use_lookahead:
        if optimizer_name in ('ranger', 'ranger21', 'ranger25'):
            warn(f'{optimizer} already has a Lookahead variant.', UserWarning, stacklevel=2)
            return optimizer

        optimizer = Lookahead(
            optimizer,
            k=kwargs.get('k', 5),
            alpha=kwargs.get('alpha', 0.5),
            pullback_momentum=kwargs.get('pullback_momentum', 'none'),
        )

    return optimizer

get_optimizer_parameters(model_or_parameter, weight_decay, wd_ban_list=('bias', 'LayerNorm.bias', 'LayerNorm.weight'))

Get optimizer parameters while filtering specified modules.

Note that you can also ban weight decay at the module-type level (e.g. LayerNorm) if you pass an nn.Module instance: just include LayerNorm in the ban list to exclude weight decay from the layer norm layer(s).

Parameters:

Name Type Description Default
model_or_parameter Union[Module, List]

model or parameters.

required
weight_decay float

weight decay.

required
wd_ban_list List[str]

weight decay ban list.

('bias', 'LayerNorm.bias', 'LayerNorm.weight')

Returns:

Name Type Description
ParamsT ParamsT

optimizer parameters.

Source code in pytorch_optimizer/optimizer/__init__.py
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
def get_optimizer_parameters(
    model_or_parameter: Union[nn.Module, List],
    weight_decay: float,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
) -> ParamsT:
    r"""Get optimizer parameters while filtering specified modules.

    Notice that, You can also ban by a module name level (e.g. LayerNorm) if you pass nn.Module instance.
    You just only need to input `LayerNorm` to exclude weight decay from the layer norm layer(s).

    Args:
        model_or_parameter (Union[nn.Module, List]): model or parameters.
        weight_decay (float): weight decay.
        wd_ban_list (List[str]): weight decay ban list.

    Returns:
        ParamsT: optimizer parameters.

    """
    banned_parameter_patterns: Set[str] = set()

    if isinstance(model_or_parameter, nn.Module):
        for module_name, module in model_or_parameter.named_modules():
            for param_name, _ in module.named_parameters(recurse=False):
                full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name
                if any(
                    banned in pattern for banned in wd_ban_list for pattern in (full_param_name, module._get_name())
                ):
                    banned_parameter_patterns.add(full_param_name)

        model_or_parameter = list(model_or_parameter.named_parameters())
    else:
        banned_parameter_patterns.update(wd_ban_list)

    return [
        {
            'params': [
                p
                for n, p in model_or_parameter
                if p.requires_grad and not any(nd in n for nd in banned_parameter_patterns)
            ],
            'weight_decay': weight_decay,
        },
        {
            'params': [
                p
                for n, p in model_or_parameter
                if p.requires_grad and any(nd in n for nd in banned_parameter_patterns)
            ],
            'weight_decay': 0.0,
        },
    ]

A2Grad

Bases: BaseOptimizer

Optimal Adaptive and Accelerated Stochastic Gradient Descent.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr Optional[float]

Learning rate. Not needed.

None
beta float

Beta.

10.0
lips float

Lipschitz constant.

10.0
rho float

Represents the degree of weighting decrease, a constant smoothing factor between 0 and 1.

0.5
variant str

Variant of A2Grad optimizer. One of 'uni', 'inc', or 'exp'.

'uni'
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/a2grad.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class A2Grad(BaseOptimizer):
    """Optimal Adaptive and Accelerated Stochastic Gradient Descent.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (Optional[float]): Learning rate. Not needed: it is validated but not used by the update.
        beta (float): Beta.
        lips (float): Lipschitz constant.
        rho (float): Represents the degree of weighting decrease, a constant smoothing factor between 0 and 1.
            Only used by the 'exp' variant.
        variant (str): Variant of A2Grad optimizer. One of 'uni', 'inc', or 'exp'.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: Optional[float] = None,
        beta: float = 10.0,
        lips: float = 10.0,
        rho: float = 0.5,
        variant: VARIANTS = 'uni',
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(lips, 'lips')
        self.validate_non_negative(rho, 'rho')
        self.validate_options(variant, 'variant', ['uni', 'inc', 'exp'])

        self.variant = variant
        self.maximize = maximize

        defaults: Defaults = {'beta': beta, 'lips': lips}
        if variant == 'exp':
            defaults.update({'rho': rho})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'A2Grad'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['alpha_k'] = 1.0
                state['v_k'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                state['avg_grad'] = grad.clone()
                state['x_k'] = p.clone()
                if self.variant == 'exp':
                    state['v_kk'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            gamma_k: float = 2.0 * group['lips'] / (group['step'] + 1)
            alpha_k_1: float = 2.0 / (group['step'] + 3)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                avg_grad, v_k, x_k = state['avg_grad'], state['v_k'], state['x_k']

                # incremental running mean of the gradients: avg <- avg + (g - avg) / (step + 1).
                # the correction must be divided by the count; multiplying by it extrapolates away from the mean.
                avg_grad.add_(grad - avg_grad, alpha=1.0 / (group['step'] + 1))

                delta_k = grad.clone()
                delta_k.add_(avg_grad, alpha=-1.0)

                delta_k_sq = delta_k.pow(2).sum()

                if self.variant in ('uni', 'inc'):
                    if self.variant == 'inc':
                        v_k.mul_((group['step'] / (group['step'] + 1)) ** 2)
                    v_k.add_(delta_k_sq)
                else:
                    v_kk = state['v_kk']

                    v_kk.mul_(group['rho']).add_(delta_k_sq, alpha=1.0 - group['rho'])
                    torch.max(v_kk, v_k, out=v_k)

                h_k = v_k.sqrt()
                if self.variant != 'uni':
                    h_k.mul_(math.sqrt(group['step'] + 1))

                coefficient = -1.0 / (gamma_k + group['beta'] * h_k.item())

                x_k.add_(grad, alpha=coefficient)

                p.mul_(1.0 - alpha_k_1).add_(x_k, alpha=alpha_k_1)
                p.add_(grad, alpha=(1.0 - alpha_k_1) * state['alpha_k'] * coefficient)

                state['alpha_k'] = alpha_k_1

        return loss

AccSGD

Bases: BaseOptimizer

Accelerating Stochastic Gradient Descent For Least Squares Regression.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
kappa float

ratio of long to short step.

1000.0
xi float

statistical advantage parameter.

10.0
constant float

any small constant under 1.

0.7
weight_decay float

weight decay (L2 penalty).

0.0
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class AccSGD(BaseOptimizer):
    """Accelerating Stochastic Gradient Descent For Least Squares Regression.

    Each parameter keeps a ``momentum_buffer`` (initialized to a copy of the parameter). A step mixes a short SGD
    step on the parameter with a long step (``lr * kappa / constant``) taken on the buffer.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        kappa (float): ratio of long to short step.
        xi (float): statistical advantage parameter.
        constant (float): any small constant under 1.
        weight_decay (float): weight decay (L2 penalty).
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        kappa: float = 1000.0,
        xi: float = 10.0,
        constant: float = 0.7,
        weight_decay: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(kappa, 'kappa')
        self.validate_non_negative(xi, 'xi')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_boundary(constant, boundary=1.0, bound_type='upper')

        self.maximize = maximize

        super().__init__(
            params,
            {
                'lr': lr,
                'kappa': kappa,
                'xi': xi,
                'constant': constant,
                'weight_decay': weight_decay,
            },
        )

    def __str__(self) -> str:
        return 'AccSGD'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if not param_state:
                param_state['momentum_buffer'] = param.clone()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            lr = group['lr']
            kappa = group['kappa']
            xi = group['xi']
            constant = group['constant']

            # long step used on the momentum buffer, and the mixing coefficients derived from kappa / xi / constant.
            long_step: float = lr * kappa / constant
            alpha: float = 1.0 - (xi * (constant ** 2) / kappa)
            beta: float = 1.0 - alpha
            zeta: float = constant / (constant + beta)

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                self.apply_weight_decay(
                    param,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                momentum = self.state[param]['momentum_buffer']
                momentum.mul_((1.0 / beta) - 1.0)
                momentum.add_(grad, alpha=-long_step)
                momentum.add_(param)
                momentum.mul_(beta)

                param.add_(grad, alpha=-lr)
                param.mul_(zeta)
                param.add_(momentum, alpha=1.0 - zeta)

        return loss

AdaBelief

Bases: BaseOptimizer

Adapting Step-sizes by the Belief in Observed Gradients.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and of the squared deviation of the gradient from that running average.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
rectify bool

Perform the rectified update similar to RAdam.

False
n_sma_threshold int

Threshold on the SMA length at or above which the rectified adaptive update is applied (recommended is 5).

5
degenerated_to_sgd bool

Perform SGD update when variance of gradient is high.

True
ams_bound bool

Whether to use the AMSBound variant.

False
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
eps float

Term added to the denominator to improve numerical stability.

1e-16
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adabelief.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
class AdaBelief(BaseOptimizer):
    """Adapting Step-sizes by the Belief in Observed Gradients.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of the squared
            deviation of the gradient from that running average.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        rectify (bool): Perform the rectified update similar to RAdam.
        n_sma_threshold (int): Threshold on the SMA length at or above which the rectified adaptive update is
            applied (recommended is 5).
        degenerated_to_sgd (bool): Perform SGD update when variance of gradient is high.
        ams_bound (bool): Whether to use the AMSBound variant.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.
        **kwargs: Extra options stored in the param-group defaults and read with ``group.get`` (``adanorm``,
            ``adanorm_r``, ``adam_debias``).

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        foreach: Optional[bool] = None,
        eps: float = 1e-16,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize
        self.foreach = foreach

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'foreach': foreach,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaBelief'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_var'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_var'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        if group.get('foreach') is False:
            return False

        # the multi-tensor path only implements the plain (non-rectified, non-AMSBound, non-AdaNorm) update.
        if group.get('adanorm') or group['rectify'] or group['ams_bound']:
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        exp_avg_vars: List[torch.Tensor],
        step_size: float,
    ) -> None:
        r"""Apply the non-rectified AdaBelief update to all tensors of a group with foreach kernels.

        Args:
            group (ParamGroup): parameter group being updated.
            params (List[torch.Tensor]): parameters with gradients.
            grads (List[torch.Tensor]): their gradients (modified in place).
            exp_avgs (List[torch.Tensor]): first-moment states (modified in place).
            exp_avg_vars (List[torch.Tensor]): gradient-residual variance states (modified in place).
            step_size (float): step size computed in ``step``. It is the same value the per-parameter path uses,
                so both paths apply the same update.

        """
        beta1, beta2 = group['betas']

        bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

        if self.maximize:
            torch._foreach_neg_(grads)

        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=group['lr'],
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        torch._foreach_lerp_(exp_avgs, grads, weight=1.0 - beta1)

        grad_residuals = torch._foreach_sub(grads, exp_avgs)

        torch._foreach_mul_(exp_avg_vars, beta2)
        torch._foreach_addcmul_(exp_avg_vars, grad_residuals, grad_residuals, value=1.0 - beta2)
        torch._foreach_add_(exp_avg_vars, group['eps'])

        # NOTE(review): the per-parameter path builds its denominator via `apply_ams_bound(..., eps=...)`;
        # confirm this plain sqrt matches it numerically.
        de_noms = torch._foreach_sqrt(exp_avg_vars)
        torch._foreach_div_(de_noms, bias_correction2_sq)

        torch._foreach_addcdiv_(params, exp_avgs, de_noms, value=-step_size)

    def _step_per_param(self, group: ParamGroup, step_size: float, n_sma: float) -> None:
        beta1, beta2 = group['betas']

        bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            self.apply_weight_decay(
                p=p,
                grad=grad,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=group['fixed_decay'],
            )

            exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

            p, grad, exp_avg, exp_avg_var = self.view_as_real(p, grad, exp_avg, exp_avg_var)

            s_grad = self.get_adanorm_gradient(
                grad=grad,
                adanorm=group.get('adanorm', False),
                exp_grad_norm=state.get('exp_grad_adanorm', None),
                r=group.get('adanorm_r', None),
            )

            exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

            grad_residual = grad - exp_avg
            exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])

            de_nom = self.apply_ams_bound(
                ams_bound=group['ams_bound'],
                exp_avg_sq=exp_avg_var,
                max_exp_avg_sq=state.get('max_exp_avg_var', None),
                eps=group['eps'],
            )

            if not group['rectify']:
                de_nom.div_(bias_correction2_sq)
                p.addcdiv_(exp_avg, de_nom, value=-step_size)
                continue

            if n_sma >= self.n_sma_threshold:
                p.addcdiv_(exp_avg, de_nom, value=-step_size)
            elif step_size > 0:
                p.add_(exp_avg, alpha=-step_size)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['exp_avg', 'exp_avg_var']
                )
                if params:
                    # pass the same step size as the per-parameter path, so both paths apply the same update.
                    self._step_foreach(
                        group, params, grads, state_dict['exp_avg'], state_dict['exp_avg_var'], step_size
                    )
            else:
                self._step_per_param(group, step_size, n_sma)

        return loss

AdaBound

Bases: BaseOptimizer

Adaptive Gradient Methods with Dynamic Bound of Learning Rate.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
final_lr float

Final learning rate.

0.1
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
gamma float

Convergence speed of the bound functions.

0.001
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
ams_bound bool

Whether to use the AMSBound variant.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adabound.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class AdaBound(BaseOptimizer):
    r"""Adaptive Gradient Methods with Dynamic Bound of Learning Rate.

    The element-wise Adam step size (``step_size / denominator``) is clamped into ``[lower_bound, upper_bound]``.
    Both bounds converge towards ``final_lr`` as the step count grows (faster for larger ``gamma``), so the optimizer
    moves from Adam-like behavior towards SGD-like behavior.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        final_lr (float): Final learning rate. Every step it is rescaled by ``lr / base_lr`` so that learning-rate
            schedulers acting on ``lr`` also move the bounds.
        betas (Betas): Coefficients used for computing running averages of the gradient and the squared gradient.
        gamma (float): Convergence speed of the bound functions. Must be positive.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.
        kwargs: Extra options stored in the parameter-group defaults (e.g. ``adam_debias``, which ``step`` reads).

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        final_lr: float = 1e-1,
        betas: Betas = (0.9, 0.999),
        gamma: float = 1e-3,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        # ``gamma == 0`` would divide by zero in the upper bound on the very first step.
        self.validate_positive(gamma, 'gamma')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'final_lr': final_lr,
            'gamma': gamma,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
            # Keep extra options (e.g. ``adam_debias``, read via ``group.get`` in ``step``) instead of dropping them.
            **kwargs,
        }

        super().__init__(params, defaults)

        # Learning rate of each group when it was first seen; used to rescale ``final_lr`` when ``lr`` changes.
        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'AdaBound'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and lazily allocate moment buffers for parameters that have gradients."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): Optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: The value returned by ``closure``, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Groups added via ``add_param_group`` after construction have no recorded base lr yet. Record their current
        # lr; otherwise ``zip`` below would silently truncate them and they would never be updated.
        self.base_lrs.extend(group['lr'] for group in self.param_groups[len(self.base_lrs):])

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # Both bounds converge to ``final_lr``, which follows any scheduler change of ``lr``.
            final_lr: float = group['final_lr'] * group['lr'] / base_lr
            lower_bound: float = final_lr * (1 - 1 / (group['gamma'] * group['step'] + 1))
            upper_bound: float = final_lr * (1 + 1 / (group['gamma'] * group['step']))

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'] * bias_correction2_sq,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                # Element-wise step size ``step_size / de_nom``, clamped into the dynamic bounds, times the momentum.
                update = torch.full_like(de_nom, fill_value=step_size)
                update.div_(de_nom).clamp_(min=lower_bound, max=upper_bound).mul_(exp_avg)

                p.add_(-update)

        return loss

AdaDelta

Bases: BaseOptimizer

An Adaptive Learning Rate Method.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

1.0
rho float

Coefficient used for computing a running average of squared gradients.

0.9
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adadelta.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class AdaDelta(BaseOptimizer):
    """An Adaptive Learning Rate Method.

    Two running averages are kept per parameter, both decayed with ``rho``. One tracks squared gradients
    (``square_avg``) and the other tracks squared updates (``acc_delta``). Each step moves the parameter by
    ``lr * grad * sqrt(acc_delta + eps) / sqrt(square_avg + eps)``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        rho (float): Coefficient used for computing a running average of squared gradients.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        rho: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(rho, 'rho', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # Named arguments cannot collide with ``kwargs`` keys, so this mirrors a dict literal with ``**kwargs`` last.
        defaults: Defaults = dict(
            lr=lr,
            rho=rho,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
            **kwargs,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaDelta'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Create the step counter and, on first sight, zeroed accumulators for every parameter with a gradient."""
        group.setdefault('step', 0)

        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if not state:
                state['square_avg'] = torch.zeros_like(param)
                state['acc_delta'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Apply one AdaDelta update to every parameter that has a gradient.

        Args:
            closure (Closure): Optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: Whatever ``closure`` returned, or None without a closure.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            rho: float = group['rho']
            eps: float = group['eps']
            lr: float = group['lr']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                state = self.state[param]
                param, grad, square_avg, acc_delta = self.view_as_real(
                    param, grad, state['square_avg'], state['acc_delta']
                )

                # Running average of squared gradients.
                square_avg.mul_(rho).addcmul_(grad, grad, value=1.0 - rho)

                # Ratio of RMS(previous updates) to RMS(gradients), applied to the current gradient.
                grad_rms = square_avg.add(eps).sqrt_()
                delta = acc_delta.add(eps).sqrt_().div_(grad_rms).mul_(grad)

                # Running average of squared updates, used by the next step.
                acc_delta.mul_(rho).addcmul_(delta, delta, value=1.0 - rho)

                param.add_(delta, alpha=-lr)

        return loss

AdaFactor

Bases: BaseOptimizer

Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.

PyTorch implementation of BigVision's AdaFactor variant.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Union[Tuple[None, float], Tuple[float, float], Tuple[float, float, float]]

Coefficients used for computing running averages of the update (beta1) and of the squared gradient (beta2). If beta1 is None, first momentum will be skipped. beta2 is an upper-bound cap on the step-dependent decay 1 - step ** decay_rate.

(0.9, 0.999)
decay_rate float

Exponent of the step-dependent decay (1 - step ** decay_rate) used for the running average of the squared gradient.

-0.8
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
clip_threshold float

Threshold of root-mean-square of final gradient update.

1.0
ams_bound bool

Whether to use the AMSBound variant.

False
scale_parameter bool

If True, the learning rate is scaled by root-mean-square of parameter.

True
relative_step bool

If True, a time-dependent step size is computed and the external learning rate is ignored.

True
warmup_init bool

If True, the time-dependent step size grows linearly as 1e-6 * step (capped at 1 / sqrt(step)) instead of starting from 1e-2.

False
eps1 float

Term added to the squared gradient to improve numerical stability (replaced by 1e-7 when momentum_dtype is float16).

1e-30
eps2 float

Lower bound on the parameter root-mean-square used to scale the learning rate when scale_parameter is True.

0.001
momentum_dtype dtype

Type of momentum variable. In the ViT paper, it was observed that storing momentum in half-precision (bfloat16 type) does not affect training dynamics and reduces optimizer overhead from 2-fold to 1.5-fold.

bfloat16
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adafactor.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
class AdaFactor(BaseOptimizer):
    """Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.

    PyTorch implementation of BigVision's AdaFactor variant.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate. Ignored when ``relative_step`` is True.
        betas (Union[Tuple[None, float], Tuple[float, float], Tuple[float, float, float]]): Coefficients used for
            computing running averages of the update (beta1) and of the squared gradient (beta2).
            If beta1 is None, first momentum will be skipped. beta2 is an upper-bound cap on the step-dependent
            decay ``1 - step ** decay_rate``.
        decay_rate (float): Exponent of the step-dependent decay used for the running average of squared gradient.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        clip_threshold (float): Threshold on the root-mean-square of the final update. Updates whose RMS exceeds
            it are scaled down to that RMS; smaller updates are left unchanged.
        ams_bound (bool): Whether to use the AMSBound variant.
        scale_parameter (bool): If True, the learning rate is scaled by root-mean-square of parameter (bounded
            below by ``eps2``).
        relative_step (bool): If True, a time-dependent step size is computed and the external ``lr`` is ignored.
        warmup_init (bool): If True, the relative step size grows as ``1e-6 * step`` (capped at
            ``1 / sqrt(step)``) instead of starting from ``1e-2``.
        eps1 (float): Term added to the squared gradient to improve numerical stability. Replaced by ``1e-7``
            when ``momentum_dtype`` is float16.
        eps2 (float): Lower bound on the parameter RMS used to scale the learning rate.
        momentum_dtype (torch.dtype): Type of momentum variable. In the ViT paper, it was observed that storing
            momentum in half-precision (bfloat16 type) does not affect training dynamics and reduces optimizer
            overhead from 2-fold to 1.5-fold.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: Optional[float] = 1e-3,
        betas: Union[Tuple[None, float], Tuple[float, float], Tuple[float, float, float]] = (0.9, 0.999),
        decay_rate: float = -0.8,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        clip_threshold: float = 1.0,
        ams_bound: bool = False,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
        eps1: float = 1e-30,
        eps2: float = 1e-3,
        momentum_dtype: torch.dtype = torch.bfloat16,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.decay_rate = decay_rate
        self.clip_threshold = clip_threshold
        # A tiny eps1 such as 1e-30 is below float16's range, so a larger value is used for float16 momentum.
        self.eps1: float = eps1 if momentum_dtype != torch.float16 else 1e-7
        self.eps2 = eps2
        self.momentum_dtype = momentum_dtype
        self.foreach = foreach
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'scale_parameter': scale_parameter,
            'relative_step': relative_step,
            'warmup_init': warmup_init,
            'eps1': eps1,
            'eps2': eps2,
            'foreach': foreach,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaFactor'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and lazily allocate per-parameter state.

        Parameters with at least two dimensions get factored row/column second-moment statistics; the others get a
        full ``exp_avg_sq``. ``exp_avg`` is only allocated when ``beta1`` (passed via ``kwargs``) is not None.
        """
        if 'step' not in group:
            group['step'] = 0

        beta1: float = kwargs.get('beta1', 0.9)

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            grad_shape: Tuple[int, ...] = grad.shape
            factored: bool = self.get_options(grad_shape)

            if len(state) == 0:
                state['RMS'] = 0.0

                if beta1 is not None:
                    state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)

                if factored:
                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_sq_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                if group['ams_bound']:
                    state['exp_avg_sq_hat'] = torch.zeros_like(grad)

    @staticmethod
    def get_relative_step_size(lr: float, step: int, relative_step: bool, warmup_init: bool) -> float:
        """Return ``lr`` unless ``relative_step`` is set, in which case return ``min(min_step, 1 / sqrt(step))``."""
        if not relative_step:
            return lr

        min_step: float = 1e-6 * step if warmup_init else 1e-2
        return min(min_step, 1.0 / math.sqrt(step))

    def get_lr(
        self,
        relative_step_size: Union[torch.Tensor, float],
        rms: Union[List[torch.Tensor], torch.Tensor, float],
        scale_parameter: bool,
    ) -> Union[Sequence[torch.Tensor], torch.Tensor, float]:
        r"""Get the learning rate(s)."""
        if not scale_parameter:
            return relative_step_size

        if not isinstance(rms, Sequence):
            return max(self.eps2, rms) * relative_step_size

        lrs = torch._foreach_maximum(rms, self.eps2)
        torch._foreach_mul_(lrs, relative_step_size)

        return lrs

    @staticmethod
    def get_options(shape: Tuple[int, ...]) -> bool:
        r"""Get `factored`."""
        return len(shape) >= 2

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        """Return whether the multi-tensor path may be used for ``group`` (never with ``cautious``)."""
        if group.get('foreach') is False:
            return False

        if group.get('cautious'):
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        exp_avg_sq_rows: List[torch.Tensor],
        exp_avg_sq_cols: List[torch.Tensor],
        exp_avg_sqs: List[torch.Tensor],
        exp_avg_sq_hats: List[torch.Tensor],
        beta1: float,
        beta2_t: float,
        relative_step_size: float,
    ) -> None:
        """Multi-tensor version of ``_step_per_param``."""
        bias_correction2: float = 1.0 - beta2_t

        if self.maximize:
            torch._foreach_neg_(grads)

        rms_values = self.get_rms(params)
        lrs = self.get_lr(relative_step_size, rms_values, group['scale_parameter'])

        updates = torch._foreach_pow(grads, 2)
        torch._foreach_add_(updates, self.eps1)

        # Split into factored (>= 2-D) and non-factored updates; offsets remember each one's original position.
        factored_offsets, non_factored_offsets = [], []
        factored_updates, non_factored_updates = [], []
        for i, grad in enumerate(grads):
            if self.get_options(grad.shape):
                factored_updates.append(updates[i])
                factored_offsets.append(i)
            else:
                non_factored_updates.append(updates[i])
                non_factored_offsets.append(i)

        if factored_updates:
            row_means, col_means = [], []
            for factored_update in factored_updates:
                row_means.append(factored_update.mean(dim=-1))
                col_means.append(factored_update.mean(dim=-2))

            torch._foreach_lerp_(exp_avg_sq_rows, row_means, weight=bias_correction2)
            torch._foreach_lerp_(exp_avg_sq_cols, col_means, weight=bias_correction2)

            self.approximate_sq_grad(exp_avg_sq_rows, exp_avg_sq_cols, factored_updates)

        if non_factored_updates:
            torch._foreach_lerp_(exp_avg_sqs, non_factored_updates, weight=bias_correction2)

            non_factored_updates = foreach_rsqrt(exp_avg_sqs)

        updates = [None] * len(grads)

        for offset, update in zip(factored_offsets, factored_updates):
            updates[offset] = update

        for offset, update in zip(non_factored_offsets, non_factored_updates):
            updates[offset] = update

        if group['ams_bound']:
            inv_updates = torch._foreach_reciprocal(updates)
            torch._foreach_maximum_(exp_avg_sq_hats, inv_updates)

            updates = foreach_rsqrt(torch._foreach_div(exp_avg_sq_hats, bias_correction2))

        torch._foreach_mul_(updates, grads)

        # Update clipping: divide by max(1, RMS(update) / clip_threshold), so only updates whose RMS exceeds the
        # threshold are scaled down. (clamp_max here would inflate every small update up to the threshold.)
        rms_values = self.get_rms(updates)
        torch._foreach_div_(rms_values, self.clip_threshold)
        torch._foreach_clamp_min_(rms_values, 1.0)

        torch._foreach_div_(updates, rms_values)
        torch._foreach_mul_(updates, lrs)

        if beta1 is not None:
            is_dtype_different: bool = self.momentum_dtype != grads[0].dtype
            if is_dtype_different:
                updates = [update.to(self.momentum_dtype) for update in updates]

            torch._foreach_lerp_(exp_avgs, updates, weight=1.0 - beta1)

            if is_dtype_different:
                updates = [exp_avg.to(grads[0].dtype) for exp_avg in exp_avgs]
            else:
                torch._foreach_copy_(updates, exp_avgs)

        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=lrs,
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        torch._foreach_sub_(params, updates)

    def _step_per_param(self, group: ParamGroup, beta1: float, beta2_t: float, relative_step_size: float) -> None:
        """Update each parameter of ``group`` one tensor at a time."""
        bias_correction2: float = 1.0 - beta2_t
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            factored: bool = self.get_options(grad.shape)

            state['RMS'] = self.get_rms(p)

            lr = self.get_lr(relative_step_size, state['RMS'], group['scale_parameter'])

            # NOTE(kozistr): adding `eps1` here instead of clipping max by eps1 later
            update = grad.square().add_(self.eps1)

            if factored:
                exp_avg_sq_row, exp_avg_sq_col = state['exp_avg_sq_row'], state['exp_avg_sq_col']

                exp_avg_sq_row.lerp_(update.mean(dim=-1), weight=bias_correction2)
                exp_avg_sq_col.lerp_(update.mean(dim=-2), weight=bias_correction2)

                self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
            else:
                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.lerp_(update, weight=bias_correction2)
                torch.rsqrt(exp_avg_sq, out=update)

            if group['ams_bound']:
                exp_avg_sq_hat = state['exp_avg_sq_hat']
                torch.max(exp_avg_sq_hat, 1.0 / update, out=exp_avg_sq_hat)
                torch.rsqrt(exp_avg_sq_hat / bias_correction2, out=update)

            update.mul_(grad)

            # Update clipping: divide by max(1, RMS(update) / clip_threshold), so only updates whose RMS exceeds the
            # threshold are scaled down. (clamp_max here would inflate every small update up to the threshold.)
            factor = self.get_rms(update).div_(self.clip_threshold).clamp_min_(1.0)
            update.div_(factor).mul_(lr)

            if beta1 is not None:
                exp_avg = state['exp_avg']
                if self.momentum_dtype != grad.dtype:
                    exp_avg.lerp_(update.to(self.momentum_dtype), weight=1.0 - beta1)
                    update = exp_avg.to(grad.dtype)
                else:
                    exp_avg.lerp_(update, weight=1.0 - beta1)
                    update = exp_avg.clone()

                if group.get('cautious'):
                    self.apply_cautious(update, grad)

            self.apply_weight_decay(
                p=p,
                grad=None,
                lr=lr,
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=group['fixed_decay'],
            )

            p.add_(-update)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step, using the foreach path when possible.

        Args:
            closure (Closure): Optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: The value returned by ``closure``, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2_cap = group['betas']

            self.init_group(group, beta1=beta1)
            group['step'] += 1

            # Step-dependent decay for the squared-gradient average, capped by beta2.
            beta2_t: float = min(beta2_cap, 1.0 - math.pow(group['step'], self.decay_rate))

            relative_step_size: float = self.get_relative_step_size(
                lr=group['lr'],
                step=group['step'],
                relative_step=group['relative_step'],
                warmup_init=group['warmup_init'],
            )

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group,
                    self.state,
                    state_keys=['exp_avg', 'exp_avg_sq_row', 'exp_avg_sq_col', 'exp_avg_sq', 'exp_avg_sq_hat'],
                )
                if params:
                    self._step_foreach(
                        group,
                        params,
                        grads,
                        state_dict['exp_avg'],
                        state_dict['exp_avg_sq_row'],
                        state_dict['exp_avg_sq_col'],
                        state_dict['exp_avg_sq'],
                        state_dict['exp_avg_sq_hat'],
                        beta1,
                        beta2_t,
                        relative_step_size,
                    )
            else:
                self._step_per_param(group, beta1, beta2_t, relative_step_size)

        return loss

get_lr(relative_step_size, rms, scale_parameter)

Get the learning rate(s).

Source code in pytorch_optimizer/optimizer/adafactor.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def get_lr(
    self,
    relative_step_size: Union[torch.Tensor, float],
    rms: Union[List[torch.Tensor], torch.Tensor, float],
    scale_parameter: bool,
) -> Union[Sequence[torch.Tensor], torch.Tensor, float]:
    r"""Return the learning rate(s) for the current step.

    Without parameter scaling the relative step size is returned unchanged. Otherwise it is multiplied by the
    parameter RMS floored at ``self.eps2``. When ``rms`` is a sequence of tensors, this is done for each tensor with
    foreach ops; otherwise it is done on the single value.
    """
    if scale_parameter:
        if isinstance(rms, Sequence):
            scaled = torch._foreach_maximum(rms, self.eps2)
            torch._foreach_mul_(scaled, relative_step_size)
            return scaled
        return max(self.eps2, rms) * relative_step_size
    return relative_step_size

get_options(shape) staticmethod

Get `factored`: True when the shape has at least two dimensions.

Source code in pytorch_optimizer/optimizer/adafactor.py
165
166
167
168
@staticmethod
def get_options(shape: Tuple[int, ...]) -> bool:
    r"""Return `factored`: whether a tensor of ``shape`` has at least two dimensions."""
    return len(shape) > 1

AdaGC

Bases: BaseOptimizer

Improving Training Stability for Large Language Model Pretraining.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
beta float

Smoothing coefficient for the exponential moving average (EMA).

0.98
lambda_abs float

Absolute clipping threshold to prevent unstable updates from gradient explosions.

1.0
lambda_rel float

Relative clipping threshold to prevent unstable updates from gradient explosions.

1.05
warmup_steps int

Number of warmup steps.

100
weight_decay float

Weight decay (L2 penalty).

0.1
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adagc.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class AdaGC(BaseOptimizer):
    """Improving Training Stability for Large Language Model Pretraining.

    AdaGC clips each parameter's gradient adaptively and then applies an Adam-style update to the clipped gradient.

    * While ``step < warmup_steps`` the gradient is scaled by ``min(lambda_abs / ||g_global||, 1)``, where
      ``||g_global||`` is the global gradient norm. The per-parameter threshold ``gamma`` tracks the smallest
      clipped-gradient norm seen so far.
    * Afterwards each gradient is scaled by ``min(lambda_rel * gamma / ||g||, 1)``, and ``gamma`` is updated as an
      EMA (smoothing ``beta``) of the clipped-gradient norm.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the (clipped) gradient and its square.
        beta (float): Smoothing coefficient for the exponential moving average (EMA) of the clipping threshold.
        lambda_abs (float): Absolute clipping threshold to prevent unstable updates from gradient explosions.
        lambda_rel (float): Relative clipping threshold to prevent unstable updates from gradient explosions.
        warmup_steps (int): Number of warmup steps.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        beta: float = 0.98,
        lambda_abs: float = 1.0,
        lambda_rel: float = 1.05,
        warmup_steps: int = 100,
        weight_decay: float = 1e-1,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(beta, 'beta', 0.0, 1.0, '[)')
        self.validate_positive(lambda_abs, 'lambda_abs')
        self.validate_positive(lambda_rel, 'lambda_rel')
        self.validate_non_negative(warmup_steps, 'warmup_steps')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'beta': beta,
            'lambda_abs': lambda_abs,
            'lambda_rel': lambda_rel,
            'warmup_steps': warmup_steps,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaGC'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(grad)
                state['exp_avg_sq'] = torch.zeros_like(grad)
                # Seed the clipping threshold with the current gradient norm so it is never read uninitialized.
                # Otherwise it would be garbage when warm-up is skipped (warmup_steps <= 1) or when the
                # parameter receives its first gradient after step 1.
                state['gamma'] = grad.norm().reshape(1)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The global gradient norm is only needed during warm-up. Compute it at most once per step
        # instead of once per parameter (which was O(n^2) in the number of parameters).
        global_grad_norm = None

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            in_warmup: bool = group['step'] < group['warmup_steps']

            # Absolute clipping factor, shared by every parameter of the group during warm-up.
            warmup_clip = 1.0
            if in_warmup:
                if global_grad_norm is None:
                    global_grad_norm = get_global_gradient_norm(self.param_groups)
                warmup_clip = min(group['lambda_abs'] / (global_grad_norm + group['eps']), 1.0)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq, gamma = state['exp_avg'], state['exp_avg_sq'], state['gamma']

                if in_warmup:
                    g_hat = grad.mul(warmup_clip)

                    g_hat_norm = g_hat.norm()

                    # gamma tracks the smallest clipped-gradient norm observed during warm-up.
                    gamma.copy_(g_hat_norm if group['step'] == 1 else min(gamma, g_hat_norm))
                else:
                    # eps guards the 0 / 0 case of an all-zero gradient, which would otherwise yield NaN.
                    h_t = min(group['lambda_rel'] * gamma / grad.norm().add_(group['eps']), 1.0)
                    g_hat = grad.mul(h_t)

                    gamma.mul_(group['beta']).add_(g_hat.norm(), alpha=1.0 - group['beta'])

                exp_avg.mul_(beta1).add_(g_hat, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_hat, g_hat, value=1.0 - beta2)

                update = (exp_avg / bias_correction1) / exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                p.add_(update, alpha=-group['lr'])

        return loss

AdaGO

Bases: BaseOptimizer

AdaGrad Meets Muon: Adaptive Stepsizes for Orthogonal Updates.

Parameters:

Name Type Description Default
params ParamsT

Parameter groups (dicts). Each group must set use_muon to choose between the Muon update and the internal AdamW update.

required
lr float

Learning rate.

0.05
momentum float

The momentum used by the internal SGD.

0.95
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
nesterov bool

Whether to use nesterov momentum.

False
gamma float

Gamma factor. Empirically, AdaGO performs robustly across a wide range of gamma values.

10.0
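eps float

Epsilon value. Lower bound eps > 0 on the stepsizes.

0.0005
v float

Initial value of the running sum of clipped squared gradient norms used by the adaptive stepsize.

1e-06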
ns_steps int

The number of Newton-Schulz iterations to run. (5 is probably always enough)

5
ns_coeffs NewtonSchulzWeights

Newton-Schulz coefficients or preset name.

'original'
use_adjusted_lr bool

Whether to use adjusted learning rate, which is from the Moonlight. Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py

False
adamw_lr float

The learning rate for the internal AdamW.

0.0003
adamw_betas tuple

The betas for the internal AdamW.

(0.9, 0.95)
adamw_wd float

The weight decay for the internal AdamW.

0.0
adamw_eps float

The epsilon for the internal AdamW.

1e-10
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Example

from pytorch_optimizer import AdaGO

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

param_groups = [
    dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
    dict(
        params=hidden_gains_biases + non_hidden_params,
        lr=3e-4,
        betas=(0.9, 0.95),
        weight_decay=0.01,
        use_muon=False,
    ),
]

optimizer = AdaGO(param_groups)

Source code in pytorch_optimizer/optimizer/muon.py
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
class AdaGO(BaseOptimizer):
    """AdaGrad Meets Muon: Adaptive Stepsizes for Orthogonal Updates.

    Groups with ``use_muon=True`` are updated with the Newton-Schulz-orthogonalized momentum, scaled by the AdaGrad-style
    step size ``max(eps, lr * min(||g||, gamma) / sqrt(v))``. Here ``v`` is a running sum (starting at ``v``) of
    ``min(||g||^2, gamma^2)``. All other groups are updated with an internal AdamW.

    Args:
        params (ParamsT): Parameter groups (dicts). Every group must set ``use_muon``.
        lr (float): Learning rate.
        momentum (float): The momentum used by the internal SGD.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        gamma (float): Gamma factor. Empirically, AdaGO performs robustly across a wide range of gamma values.
        eps (float): Epsilon value. Lower bound eps > 0 on the stepsizes.
        v (float): Initial value of the running sum of clipped squared gradient norms. Must be positive.
        nesterov (bool): Whether to use nesterov momentum.
        ns_steps (int): The number of Newton-Schulz iterations to run. (5 is probably always enough)
        ns_coeffs (NewtonSchulzWeights): Newton-Schulz coefficients or preset name.
        use_adjusted_lr (bool): Whether to use adjusted learning rate, which is from the Moonlight.
            Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
        adamw_lr (float): The learning rate for the internal AdamW.
        adamw_betas (tuple): The betas for the internal AdamW.
        adamw_wd (float): The weight decay for the internal AdamW.
        adamw_eps (float): The epsilon for the internal AdamW.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    Raises:
        ValueError: if a parameter group does not set ``use_muon``.

    Example:
        from pytorch_optimizer import AdaGO

        hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
        hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
        non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

        param_groups = [
            dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
            dict(
                params=hidden_gains_biases + non_hidden_params,
                lr=3e-4,
                betas=(0.9, 0.95),
                weight_decay=0.01,
                use_muon=False,
            ),
        ]

        optimizer = AdaGO(param_groups)

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 5e-2,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        gamma: float = 10.0,
        eps: float = 5e-4,
        v: float = 1e-6,
        nesterov: bool = False,
        ns_steps: int = 5,
        ns_coeffs: NewtonSchulzWeights = 'original',
        use_adjusted_lr: bool = False,
        adamw_lr: float = 3e-4,
        adamw_betas: Betas = (0.9, 0.95),
        adamw_wd: float = 0.0,
        adamw_eps: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(adamw_lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_positive(ns_steps, 'ns_steps')
        self.validate_positive(gamma, 'gamma')
        self.validate_positive(eps, 'eps')
        self.validate_positive(v, 'v')
        self.validate_betas(adamw_betas)
        self.validate_non_negative(adamw_wd, 'adamw_wd')
        self.validate_non_negative(adamw_eps, 'adamw_eps')
        ns_coeffs = get_newton_schulz_weights(ns_coeffs)

        self.maximize = maximize

        # Materialize once: the loop below iterates `params` and super().__init__ iterates it again, which would
        # silently receive nothing if a generator were passed.
        params = list(params)

        for group in params:
            group = cast(ParamGroup, group)
            if 'use_muon' not in group:
                raise ValueError('`use_muon` must be set.')

            if group['use_muon']:
                group['lr'] = group.get('lr', lr)
                group['momentum'] = group.get('momentum', momentum)
                group['nesterov'] = group.get('nesterov', nesterov)
                group['weight_decay'] = group.get('weight_decay', weight_decay)
                group['ns_steps'] = group.get('ns_steps', ns_steps)
                group['ns_coeffs'] = get_newton_schulz_weights(group.get('ns_coeffs', ns_coeffs))
                group['gamma'] = group.get('gamma', gamma)
                group['eps'] = group.get('eps', eps)
                group['v'] = group.get('v', v)
                group['use_adjusted_lr'] = group.get('use_adjusted_lr', use_adjusted_lr)
            else:
                group['lr'] = group.get('lr', adamw_lr)
                group['betas'] = group.get('betas', adamw_betas)
                group['eps'] = group.get('eps', adamw_eps)
                group['weight_decay'] = group.get('weight_decay', adamw_wd)

            group['weight_decouple'] = group.get('weight_decouple', weight_decouple)

        super().__init__(params, kwargs)

    def __str__(self) -> str:
        return 'AdaGO'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if group['use_muon']:
                    state['momentum_buffer'] = torch.zeros_like(p)
                    state['v'] = torch.tensor(group['v'], dtype=p.dtype, device=p.device)
                else:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                if group['use_muon']:
                    buf, v = state['momentum_buffer'], state['v']

                    # ||G_t|| of the raw gradient. Taken before the in-place Nesterov interpolation below,
                    # so both the accumulator and the step size use the same norm.
                    grad_norm = grad.norm(p=2.0)

                    buf.lerp_(grad, weight=1.0 - group['momentum'])

                    # v accumulates clipped squared gradient norms: v += min(||G_t||^2, gamma^2).
                    v.add_(grad_norm.square().clamp_(max=group['gamma'] ** 2))

                    update = grad.lerp_(buf, weight=group['momentum']) if group['nesterov'] else buf
                    if update.ndim > 2:
                        update = update.view(len(update), -1)

                    update = zero_power_via_newton_schulz_5(
                        update, num_steps=group['ns_steps'], weights=group['ns_coeffs']
                    )

                    if group.get('cautious'):
                        self.apply_cautious(update, grad)

                    lr: float = get_adjusted_lr(group['lr'], p.size(), use_adjusted_lr=group['use_adjusted_lr'])

                    # Step size max(eps, lr * min(||G_t||, gamma) / sqrt(v)). Since v holds a sum of *squared*
                    # norms, its square root is used. The ratio is converted to a float *before* the max;
                    # otherwise `max` could return the float `eps` and `.item()` would raise AttributeError.
                    adaptive_lr: float = (lr * grad_norm.clamp(max=group['gamma']) / v.sqrt()).item()

                    p.add_(update.reshape(p.shape), alpha=-max(group['eps'], adaptive_lr))
                else:
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                    beta1, beta2 = group['betas']

                    bias_correction1: float = self.debias(beta1, group['step'])
                    bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

                    exp_avg.lerp_(grad, weight=1.0 - beta1)
                    exp_avg_sq.lerp_(grad.square(), weight=1.0 - beta2)

                    de_nom = exp_avg_sq.sqrt().add_(group['eps']).div_(bias_correction2_sq)

                    p.addcdiv_(exp_avg / bias_correction1, de_nom, value=-group['lr'])

        return loss

AdaHessian

Bases: BaseOptimizer

An Adaptive Second Order Optimizer for Machine Learning.

Requires loss.backward(create_graph=True) in order to calculate Hessians.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.1
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
hessian_power float

Exponent applied to the Hessian trace for scaling updates.

1.0
update_period int

Number of steps between refreshes of the Hessian-trace estimate.

1
num_samples int

Number of times to sample z when approximating the Hessian trace.

1
hessian_distribution HutchinsonG

Type of distribution used to initialize the Hutchinson trace estimator.

'rademacher'
eps float

Term added to the denominator to improve numerical stability.

1e-16
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adahessian.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class AdaHessian(BaseOptimizer):
    """An Adaptive Second Order Optimizer for Machine Learning.

    Requires `loss.backward(create_graph=True)` in order to calculate Hessians.

    The Hessian estimate is refreshed on step 1 and then every ``update_period`` steps. It comes either from
    Hutchinson sampling or from the ``hessian`` argument of ``step``. On a refresh step the estimate is folded into
    the running average of its square; on the other steps the stored running average is reused.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of gradient and the squared Hessian trace.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        hessian_power (float): Exponent applied to the Hessian trace for scaling updates.
        update_period (int): Refresh the Hessian estimate every `update_period` steps, starting at step 1. Must be
            positive.
        num_samples (int): Number of times to sample `z` when approximating the Hessian trace. Must be positive.
        hessian_distribution (HutchinsonG): Type of distribution used to initialize the Hutchinson trace estimator.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-1,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        hessian_power: float = 1.0,
        update_period: int = 1,
        num_samples: int = 1,
        hessian_distribution: HutchinsonG = 'rademacher',
        eps: float = 1e-16,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')
        self.validate_range(hessian_power, 'Hessian Power', 0, 1, range_type='(]')
        # update_period is used as a modulus in step(); 0 would only fail there with ZeroDivisionError.
        self.validate_positive(update_period, 'update_period')
        self.validate_positive(num_samples, 'num_samples')

        self.update_period = update_period
        self.num_samples = num_samples
        self.distribution = hessian_distribution
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'hessian_power': hessian_power,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaHessian'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_hessian_diag_sq'] = torch.zeros_like(p)

    def _is_hessian_step(self, step: int) -> bool:
        r"""Return True when the Hessian estimate is refreshed on 1-based `step` (1, 1 + update_period, ...)."""
        return (step - 1) % self.update_period == 0

    @torch.no_grad()
    def step(self, closure: Closure = None, hessian: Optional[List[torch.Tensor]] = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # `group['step']` is only incremented further below, so the step about to be taken is the stored value + 1.
        # The same numbering is used when the estimate is consumed. Otherwise the Hessian would be computed on one
        # step and folded in on the next, and the first steps would divide by `eps` alone.
        upcoming_step: int = self.param_groups[0].get('step', 0) + 1

        if hessian is not None:
            self.set_hessian(self.param_groups, self.state, hessian)
        elif self._is_hessian_step(upcoming_step):
            self.zero_hessian(self.param_groups, self.state)
            self.compute_hutchinson_hessian(
                param_groups=self.param_groups,
                state=self.state,
                num_samples=self.num_samples,
                distribution=self.distribution,
            )

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            step_size: float = self.apply_adam_debias(group.get('adam_debias', False), group['lr'], bias_correction1)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                # Fold in the Hessian only on the steps where it was (re)computed or explicitly supplied.
                if 'hessian' in state and (hessian is not None or self._is_hessian_step(group['step'])):
                    exp_hessian_diag_sq.mul_(beta2).addcmul_(state['hessian'], state['hessian'], value=1.0 - beta2)

                de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(group['hessian_power'] / 2).add_(group['eps'])

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss

Adai

Bases: BaseOptimizer

Disentangling the Effects of Adaptive Learning Rate and Momentum.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients (beta0, beta2): beta0 scales the adaptive per-element momentum coefficient, and beta2 is used for the running average of the squared gradient.

(0.1, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
stable_weight_decay bool

Perform stable weight decay.

False
dampening float

Dampening for momentum. When dampening < 1, it exhibits adaptive-moment behavior.

1.0
eps float

Term added to the denominator to improve numerical stability.

0.001
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adai.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
class Adai(BaseOptimizer):
    """Disentangling the Effects of Adaptive Learning Rate and Momentum.

    The step runs two passes over the parameters. The first pass updates the running average of the squared
    gradient and accumulates its bias-corrected mean over all parameters. The second pass derives a per-element
    momentum coefficient from the ratio to that mean and applies the update.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients ``(beta0, beta2)``. ``beta0`` scales the adaptive momentum coefficient and
            ``beta2`` is used for the running average of the squared gradient.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        stable_weight_decay (bool): Perform stable weight decay.
        dampening (float): Dampening for momentum. When dampening < 1, it exhibits adaptive-moment behavior.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.1, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        stable_weight_decay: bool = False,
        dampening: float = 1.0,
        eps: float = 1e-3,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'stable_weight_decay': stable_weight_decay,
            'dampening': dampening,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adai'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['beta1_prod'] = torch.ones_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        # First pass: update the second moment and accumulate its bias-corrected sum over all parameters.
        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            _, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                param_size += p.numel()

                # Negates `grad` in place when maximizing; the second pass reuses this already-negated tensor.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if group.get('use_gc'):
                    centralize_gradient(grad, gc_conv_only=False)

                if not group['stable_weight_decay'] and group['weight_decay'] > 0.0:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                exp_avg_sq_hat_sum += exp_avg_sq.sum() / bias_correction2

        if param_size == 0:
            raise ZeroParameterSizeError

        exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

        # Second pass: adaptive per-element momentum and the parameter update.
        for group in self.param_groups:
            beta0, beta2 = group['betas']

            beta0_dp: float = math.pow(beta0, 1.0 - group['dampening'])
            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                # Do NOT call maximize_gradient again: the first pass already negated `grad` in place, and a
                # second negation would flip it back so that `maximize=True` silently minimized.
                grad = p.grad

                state = self.state[p]

                if group['stable_weight_decay'] and group['weight_decay'] > 0.0:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                exp_avg_sq_hat = exp_avg_sq / bias_correction2
                beta1 = (
                    1.0
                    - (exp_avg_sq_hat / exp_avg_sq_hat_mean).pow_(1.0 / (3.0 - 2.0 * group['dampening'])).mul_(beta0)
                ).clamp_(0.0, 1.0 - group['eps'])
                beta3 = (1.0 - beta1).pow_(group['dampening'])

                beta1_prod = state['beta1_prod']
                beta1_prod.mul_(beta1)

                exp_avg.mul_(beta1).addcmul_(beta3, grad)
                exp_avg_hat = exp_avg.div(1.0 - beta1_prod).mul_(beta0_dp)

                p.add_(exp_avg_hat, alpha=-group['lr'])

        return loss

Adalite

Bases: BaseOptimizer

Adalite optimizer.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
g_norm_min float

Minimum gradient norm threshold.

1e-10
ratio_min float

Minimum ratio value for adaptive adjustment.

0.0001
tau float

Time constant controlling parameter smoothing or decay behavior.

1.0
eps1 float

Term added to the denominator to improve numerical stability.

1e-06
eps2 float

Additional term added to the denominator for extra numerical stability.

1e-10
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adalite.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
class Adalite(BaseOptimizer):
    r"""Adalite optimizer.

    Parameters with fewer than two dimensions keep dense first/second-moment buffers. Parameters with two or
    more dimensions keep a factored form of both moments: row/column statistics for the second moment and a
    rank-1 ``c @ s @ r`` factorization for the first moment. NOTE(review): the factored path indexes and
    multiplies as if the parameter were 2-D; confirm the behavior for conv kernels and other higher-rank tensors.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for the running averages of the gradient (beta1) and of the squared
            deviation of the gradient from the first moment (beta2).
        weight_decay (float): Weight decay coefficient. ``weight_decay * p`` is added to the normalized update
            before the update is scaled by ``lr``.
        weight_decouple (bool): Stored in the parameter group. ``step`` does not read it, so it currently
            appears to have no effect.
        fixed_decay (bool): Stored in the parameter group. ``step`` does not read it, so it currently appears
            to have no effect.
        g_norm_min (float): Lower clamp on the gradient norm in the trust ratio ``||p|| / ||g||``.
        ratio_min (float): Lower clamp on the trust ratio used to rescale the gradient.
        tau (float): Softmax temperature of the row/column importance weights used to refactor the momentum.
        eps1 (float): Term added to the bias-corrected second moment before taking the square root.
        eps2 (float): Lower clamp on the normalizing sums/products of the factored moment estimates.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 1e-2,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        g_norm_min: float = 1e-10,
        ratio_min: float = 1e-4,
        tau: float = 1.0,
        eps1: float = 1e-6,
        eps2: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.maximize = maximize

        # Extra keyword arguments are stored in every parameter group alongside the named hyper-parameters.
        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'g_norm_min': g_norm_min,
            'ratio_min': ratio_min,
            'tau': tau,
            'eps1': eps1,
            'eps2': eps2,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adalite'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the group step counter and lazily create per-parameter state.

        Raises NoSparseGradientError for sparse gradients and NoComplexParameterError for complex parameters.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if len(p.shape) < 2:
                    # Dense moments for scalars / vectors.
                    state['m_avg'] = torch.zeros_like(p)
                    state['v_avg'] = torch.zeros_like(p)
                else:
                    # Factored second moment: one statistic per row (v_avg_0) and per column (v_avg_1).
                    state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                    state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

                    # Factored first moment, rebuilt in step() as m_avg_c @ m_avg_u @ m_avg_r.
                    state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
                    state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
                    state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step and return the closure's loss (or None)."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # Trust-ratio rescaling ||p|| / ||g|| (both bounds clamped). This writes into p.grad in place.
                # It is skipped when the dims sum to at most 1, i.e. 0-d tensors and shape (1,).
                if sum(grad.shape) > 1:
                    trust_ratio = (p.norm() / grad.norm().clip(min=group['g_norm_min'])).clip(min=group['ratio_min'])
                    grad.mul_(trust_ratio)

                if len(grad.shape) < 2:
                    # m and v are the state tensors themselves, so the in-place updates below persist.
                    m = state['m_avg']
                    v = state['v_avg']
                else:
                    # Rebuild dense (temporary) moments from the factored state.
                    r, c = state['v_avg_0'][:, None], state['v_avg_1'][None, :]
                    v = (r * c) / r.sum().clamp(min=group['eps2'])
                    m = state['m_avg_c'] @ state['m_avg_u'] @ state['m_avg_r']

                # EMA of the gradient, then EMA of the squared deviation from the updated first moment.
                m.lerp_(grad, 1.0 - beta1)
                v.lerp_((grad - m).square(), 1.0 - beta2)

                v_avg = v / (1.0 - beta2 ** group['step'])

                if len(grad.shape) == 2:
                    # Pull m toward grad element-wise, by 1 - (row importance * column importance).
                    imp_c = softmax(v.mean(dim=1), dim=0)[:, None]
                    imp_r = softmax(v.mean(dim=0), dim=0)[None, :]
                    m.lerp_(grad, 1.0 - imp_c * imp_r)

                # Out-of-place blend beta1 * m + (1 - beta1) * grad, used as the update direction.
                u = m.lerp(grad, 1.0 - beta1)

                if len(grad.shape) < 2:
                    state['m_avg'] = m
                    state['v_avg'] = v
                else:
                    # Store the second moment as row sums and normalized column sums.
                    state['v_avg_0'] = v.sum(dim=1)
                    state['v_avg_1'] = v.sum(dim=0) / v.sum().clamp(min=group['eps2'])

                    imp_c = softmax(v.mean(dim=1) / group['tau'], dim=-1)[:, None]
                    imp_r = softmax(v.mean(dim=0) / group['tau'], dim=-1)[None, :]

                    # Importance-weighted column (c) and row (r) summaries of m.
                    c = ((m * imp_r).sum(dim=1))[:, None]
                    r = ((m * imp_c).sum(dim=0))[None, :]

                    # s = <m, c r> / ||c r||^2 is the least-squares scale so that c @ s @ r best fits m (Frobenius).
                    s = (c.T @ m @ r.T) / (c.T @ c @ r @ r.T).clamp(min=group['eps2'])

                    state['m_avg_c'] = c
                    state['m_avg_r'] = r
                    state['m_avg_u'] = s

                # Adaptive normalization by the bias-corrected second moment.
                u.div_((v_avg + group['eps1']).sqrt())

                u = u.reshape(p.shape)
                # Weight decay is added after normalization, so it is not divided by sqrt(v_avg).
                u.add_(p, alpha=group['weight_decay'])

                p.add_(u, alpha=-group['lr'])

        return loss

AdaLOMO

Bases: BaseOptimizer

Low-memory Optimization with Adaptive Learning Rate.

Parameters:

Name Type Description Default
model Module

PyTorch model.

required
lr float

Learning rate.

0.001
weight_decay float

Weight decay (L2 penalty).

0.0
loss_scale float

Loss scale.

2.0 ** 10
clip_threshold float

Threshold of root-mean-square of final gradient update.

1.0
decay_rate float

Exponent of the step-dependent decay of the squared-gradient running averages: beta2_t = 1 - t^decay_rate.

-0.8
clip_grad_norm Optional[float]

Clip gradient norm.

None
clip_grad_value Optional[float]

Clip gradient value.

None
eps1 float

Term added to the squared gradient before it updates the second-moment estimates, for numerical stability.

1e-30
eps2 float

Lower bound on the parameter RMS that scales the learning rate.

0.001
Source code in pytorch_optimizer/optimizer/lomo.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
class AdaLOMO(BaseOptimizer):
    """Low-memory Optimization with Adaptive Learning Rate.

    Every trainable parameter gets a gradient hook that runs the fused update function during backward.
    ``fused_backward`` scales the loss, runs backward, and then calls the update function once more so that
    any gradients still stored on parameters are consumed. Second moments are kept in float32: a full tensor
    for 0-d/1-D parameters, and row/column statistics for higher-rank parameters.

    Args:
        model (nn.Module): PyTorch model.
        lr (float): Learning rate.
        weight_decay (float): Decoupled weight decay coefficient.
        loss_scale (float): Loss scale. The loss is multiplied by it before backward and gradients are divided
            by it.
        clip_threshold (float): Threshold of root-mean-square of final gradient update.
        decay_rate (float): Exponent of the step-dependent second-moment decay, ``beta2_t = 1 - t ** decay_rate``.
        clip_grad_norm (Optional[float]): Clip gradient norm. The clipping coefficient is computed by
            ``grad_norm``.
        clip_grad_value (Optional[float]): Clip gradient value.
        eps1 (float): Term added to the squared gradient before updating the second moments.
        eps2 (float): Lower bound on the parameter RMS used to scale the learning rate.

    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        loss_scale: float = 2.0 ** 10,
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        clip_grad_norm: Optional[float] = None,
        clip_grad_value: Optional[float] = None,
        eps1: float = 1e-30,
        eps2: float = 1e-3,
        **kwargs,
    ) -> None:  # fmt: skip
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(loss_scale, 'loss_scale')
        self.validate_non_negative(clip_threshold, 'clip_threshold')
        self.validate_non_negative(clip_grad_norm, 'clip_grad_norm')
        self.validate_non_negative(clip_grad_value, 'clip_grad_value')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.model = model
        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_scale = loss_scale
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.clip_grad_norm = clip_grad_norm
        self.clip_grad_value = clip_grad_value
        self.eps1 = eps1
        self.eps2 = eps2

        self.num_steps: int = 0
        self.gather_norm: bool = False
        self.grad_norms: List[torch.Tensor] = []
        self.clip_coef: Optional[float] = None

        self.local_rank: int = int(os.environ.get('LOCAL_RANK', '0'))
        self.zero3_enabled: bool = is_deepspeed_zero3_enabled()

        self.grad_func: Callable[[Any], Any] = self.fuse_update_zero3() if self.zero3_enabled else self.fuse_update()

        self.exp_avg_sq = {}
        self.exp_avg_sq_row = {}
        self.exp_avg_sq_col = {}

        self.initialize_states()

        defaults: Defaults = {
            'lr': lr,
            'weight_decay': weight_decay,
            'clip_grad_norm': clip_grad_norm,
            'clip_grad_value': clip_grad_value,
            'eps1': eps1,
            'eps2': eps2,
        }

        super().__init__(self.model.parameters(), defaults)

    def __str__(self) -> str:
        return 'AdaLOMO'

    def initialize_states(self) -> None:
        """Allocate float32 second-moment buffers per named parameter and register the update hooks."""
        for n, p in self.model.named_parameters():
            if self.zero3_enabled:  # pragma: no cover
                if len(p.ds_shape) == 1:
                    self.exp_avg_sq[n] = torch.zeros(p.ds_shape[0], dtype=torch.float32, device=p.device)
                else:
                    self.exp_avg_sq_row[n] = torch.zeros(p.ds_shape[0], dtype=torch.float32, device=p.device)
                    self.exp_avg_sq_col[n] = torch.zeros(p.ds_shape[1], dtype=torch.float32, device=p.device)
            elif len(p.shape) <= 1:
                # Dense buffer shaped like the parameter. Using ``p.shape`` also covers 0-d (scalar) parameters,
                # for which ``p.shape[0]`` would raise.
                self.exp_avg_sq[n] = torch.zeros(p.shape, dtype=torch.float32, device=p.device)
            else:
                self.exp_avg_sq_row[n] = torch.zeros(p.shape[0], dtype=torch.float32, device=p.device)
                self.exp_avg_sq_col[n] = torch.zeros(p.shape[1], dtype=torch.float32, device=p.device)

            if p.requires_grad:
                p.register_hook(self.grad_func)

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

    def _clip_gradient(self, grad_fp32: torch.Tensor) -> None:
        """Clip ``grad_fp32`` in place by value and/or by the norm coefficient computed in ``grad_norm``."""
        if self.clip_grad_value is not None and self.clip_grad_value > 0.0:
            grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)
        if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is not None:
            grad_fp32.mul_(self.clip_coef)

    def _adaptive_update(self, name: str, grad_fp32: torch.Tensor, factored: bool) -> torch.Tensor:
        """Update the second-moment state of parameter ``name`` and return the RMS-clipped update direction."""
        # beta2_t = 1 - t ** decay_rate. decay_rate is negative by default, so beta2_t grows toward 1 over time.
        beta2_t: float = 1.0 - math.pow(
            self.num_steps, self.decay_rate if self.num_steps > 0 else -self.decay_rate
        )

        update = grad_fp32.pow(2).add_(self.eps1)

        if factored:
            self.exp_avg_sq_row[name].mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
            self.exp_avg_sq_col[name].mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

            # ``update`` is passed as the output buffer. Presumably it receives the factored inverse-RMS
            # preconditioner; confirm against BaseOptimizer.approximate_sq_grad.
            self.approximate_sq_grad(self.exp_avg_sq_row[name], self.exp_avg_sq_col[name], update)
            update.mul_(grad_fp32)
        else:
            self.exp_avg_sq[name].mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
            update = self.exp_avg_sq[name].rsqrt().mul_(grad_fp32)

        # Scale the update down so that its RMS does not exceed clip_threshold.
        factor = cast(torch.Tensor, self.get_rms(update)).div_(self.clip_threshold).clamp_min_(1.0)
        return update.div_(factor)

    def fuse_update(self) -> Callable[[Any], Any]:
        """Build the per-parameter update function used when DeepSpeed ZeRO-3 is disabled."""

        @torch.no_grad()
        def func(x: Any) -> Any:
            for n, p in self.model.named_parameters():
                if not p.requires_grad or p.grad is None:
                    continue

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                if self.loss_scale:
                    grad_fp32.div_(self.loss_scale)

                if self.gather_norm:
                    # Norm-gathering pass only: record the norm and leave the parameter untouched.
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                    continue

                self._clip_gradient(grad_fp32)

                update = self._adaptive_update(n, grad_fp32, factored=len(p.shape) > 1)

                # For float32 params ``.to`` returns ``p`` itself; otherwise it returns a float32 copy.
                p_fp32 = p.to(torch.float32)
                p_rms = torch.norm(p_fp32, 2.0) / math.sqrt(p.numel())

                lr = self.lr * max(self.eps2, p_rms)

                # Decay the float32 tensor: decaying ``p`` would be lost by ``p.copy_(p_fp32)`` for
                # low-precision parameters.
                self.apply_weight_decay(
                    p_fp32,
                    grad_fp32,
                    lr,
                    self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                # Apply the adaptive update (previously the raw gradient was applied and ``update`` was unused).
                p_fp32.add_(update, alpha=-lr)
                p.copy_(p_fp32)

            return x

        return func

    def fuse_update_zero3(self) -> Callable[[Any], Any]:  # pragma: no cover
        """Build the per-parameter update function used when DeepSpeed ZeRO-3 is enabled."""

        @torch.no_grad()
        def func(x: torch.Tensor) -> torch.Tensor:
            for n, p in self.model.named_parameters():
                if p.grad is None:
                    continue

                all_reduce(p.grad, op=ReduceOp.AVG, async_op=False)

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                if self.loss_scale:
                    grad_fp32.div_(self.loss_scale)

                if self.gather_norm:
                    # Norm-gathering pass only: do not update the parameter (matches ``fuse_update``).
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                    continue

                # Slice of the flattened parameter owned by this rank.
                partition_size: int = p.ds_tensor.numel()
                start: int = partition_size * self.local_rank
                end: int = min(start + partition_size, grad_fp32.numel())

                self._clip_gradient(grad_fp32)

                update = self._adaptive_update(n, grad_fp32, factored=len(p.ds_shape) > 1)

                partitioned_update = update.view(-1).narrow(0, start, end - start)

                param_fp32 = p.ds_tensor.to(torch.float32)
                partitioned_p = param_fp32.narrow(0, 0, end - start)

                # Global RMS of the parameter: sum of squared local norms across ranks, then / numel and sqrt.
                p_rms = torch.norm(partitioned_p, 2.0).pow_(2)
                all_reduce(p_rms, op=ReduceOp.SUM)
                p_rms.div_(p.ds_numel).sqrt_()

                lr = self.lr * max(self.eps2, p_rms)

                self.apply_weight_decay(
                    p=partitioned_p,
                    grad=grad_fp32,
                    lr=lr,
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                partitioned_p.add_(partitioned_update, alpha=-lr)

                p.ds_tensor[:end - start] = partitioned_p  # fmt: skip

            return x

        return func

    def fused_backward(self, loss, lr: float) -> None:
        """Set the learning rate, advance the step counter, backpropagate ``loss`` and flush remaining updates."""
        self.lr = lr

        if self.loss_scale:
            loss = loss * self.loss_scale

        self.num_steps += 1

        loss.backward()

        self.grad_func(0)

    def grad_norm(self, loss) -> None:
        """Backpropagate ``loss`` without updating parameters and compute the global-norm clipping coefficient.

        Requires ``clip_grad_norm`` to be set. The graph is retained so ``fused_backward`` can reuse it.
        """
        self.gather_norm = True
        self.grad_norms = []

        if self.loss_scale:
            loss = loss * self.loss_scale

        loss.backward(retain_graph=True)

        self.grad_func(0)

        with torch.no_grad():
            self.grad_norms = torch.stack(self.grad_norms)

            total_norm = torch.norm(self.grad_norms, 2.0)
            self.clip_coef = torch.clamp(float(self.clip_grad_norm) / (total_norm + 1e-6), max=1.0)

        self.gather_norm = False

AdaMax

Bases: BaseOptimizer

AdaMax, the infinity-norm variant of Adam.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for the running average of the gradient (beta1) and the decay of the exponentially weighted infinity norm of the gradient (beta2).

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamax.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
class AdaMax(BaseOptimizer):
    """AdaMax, the infinity-norm variant of Adam.

    Adam's second moment is replaced by an exponentially weighted infinity norm of past gradients,
    ``u_t = max(beta2 * u_{t-1}, |g_t| + eps)``. The parameter moves by ``step_size * m_t / u_t``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients for the running average of the gradient (beta1) and for the decay of the
            weighted infinity norm (beta2).
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        eps (float): Term added to the absolute gradient before the infinity-norm update, for numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        for value, label in ((weight_decay, 'weight_decay'), (eps, 'eps')):
            self.validate_non_negative(value, label)

        self.maximize = maximize

        group_defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
            **kwargs,
        )
        super().__init__(params, group_defaults)

    def __str__(self) -> str:
        return 'AdaMax'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Start the group's step counter and create moment buffers for parameters that have none yet."""
        group.setdefault('step', 0)

        wants_adanorm = group.get('adanorm')

        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue
            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if state:
                continue

            state['exp_avg'] = torch.zeros_like(param)
            state['exp_inf'] = torch.zeros_like(param)
            if wants_adanorm:
                state['exp_grad_adanorm'] = torch.zeros((1,), dtype=gradient.dtype, device=gradient.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Run one AdaMax update over every parameter group and return the closure's loss, if any."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            # Bias-corrected step size for the first moment (see apply_adam_debias).
            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=self.debias(beta1, group['step']),
            )

            for param in group['params']:
                if param.grad is None:
                    continue

                gradient = param.grad
                self.maximize_gradient(gradient, maximize=self.maximize)

                state = self.state[param]
                param, gradient, exp_avg, exp_inf = self.view_as_real(
                    param, gradient, state['exp_avg'], state['exp_inf']
                )

                self.apply_weight_decay(
                    p=param,
                    grad=gradient,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=gradient,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

                # Weighted infinity norm, written back into the state buffer:
                # u_t = max(beta2 * u_{t-1}, |g_t| + eps).
                exp_inf.mul_(beta2)
                torch.maximum(exp_inf, gradient.abs().add_(group['eps']), out=exp_inf)

                param.addcdiv_(exp_avg, exp_inf, value=-step_size)

        return loss

AdamC

Bases: BaseOptimizer

Why Gradients Rapidly Increase Near the End of Training.

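Set normalized=True in the parameter groups of LayerNorm and BatchNorm layers; for those groups the weight decay step size is lr^2 / max_lr, where max_lr is the learning rate passed to the constructor.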

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and of its square.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
ams_bound bool

Whether to use the AMSBound variant.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamc.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class AdamC(BaseOptimizer):
    """AdamC, from "Why Gradients Rapidly Increase Near the End of Training".

    Set `normalized=True` for LayerNorm and BatchNorm layers. In parameter groups that set it, the weight decay
    step size becomes ``lr ** 2 / max_lr`` instead of ``lr``. Here ``max_lr`` is the learning rate given to the
    constructor.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for the running averages of the gradient and of its square.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        for value, label in ((weight_decay, 'weight_decay'), (eps, 'eps')):
            self.validate_non_negative(value, label)

        self.maximize = maximize
        # Reference learning rate for the normalized-layer weight-decay correction.
        self.max_lr: float = lr

        group_defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            ams_bound=ams_bound,
            eps=eps,
            **kwargs,
        )
        super().__init__(params, group_defaults)

    def __str__(self) -> str:
        return 'AdamC'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Start the group's step counter and allocate Adam moment buffers on first use."""
        group.setdefault('step', 0)

        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue
            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))
            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            state = self.state[param]
            if state:
                continue

            state['exp_avg'] = torch.zeros_like(param)
            state['exp_avg_sq'] = torch.zeros_like(param)
            if group['ams_bound']:
                state['max_exp_avg_sq'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Run one AdamC update over every parameter group and return the closure's loss, if any."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            current_step = group['step']
            lr = group['lr']

            bias_correction1: float = self.debias(beta1, current_step)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, current_step))

            # Normalized layers decay with lr^2 / max_lr so the effective decay tracks the schedule.
            if group.get('normalized'):
                wd_step_size: float = (lr ** 2) / self.max_lr
            else:
                wd_step_size = lr

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                state = self.state[param]
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                self.apply_weight_decay(
                    p=param,
                    grad=gradient,
                    lr=wd_step_size,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(gradient, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(gradient, gradient, value=1.0 - beta2)

                denominator = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )
                denominator.div_(bias_correction2_sq)

                param.addcdiv_(exp_avg / bias_correction1, denominator, value=-lr)

        return loss

AdamG

Bases: BaseOptimizer

Towards Stability of Parameter-free Optimization.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

1.0
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.95, 0.999, 0.95)
p float

The p value in the numerator function s(x) = p * x^q.

0.2
q float

The q value in the numerator function s(x) = p * x^q.

0.24
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamg.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
class AdamG(BaseOptimizer):
    """Towards Stability of Parameter-free Optimization.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate. Each step uses ``min(lr, 1 / sqrt(step))`` as its step size.
        betas (Betas): ``(beta1, beta2, beta3)``: smoothing coefficients for the first moment ``m``, the second
            moment ``v`` and the numerator accumulator ``r`` (a running average of ``s(v)``), respectively.
        p (float): The p value in the numerator function `s(x) = p * x^q`.
        q (float): The q value in the numerator function `s(x) = p * x^q`.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        betas: Betas = (0.95, 0.999, 0.95),
        p: float = 0.2,
        q: float = 0.24,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_positive(p, 'p')
        self.validate_positive(q, 'q')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        # `p` / `q` live on the optimizer itself (shared by every group), not in the per-group defaults.
        self.p, self.q = p, q
        self.maximize = maximize

        group_defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }
        super().__init__(params, group_defaults)

    def __str__(self) -> str:
        return 'AdamG'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if not state:
                # m: first moment, v: second moment, r: running average of s(v).
                for key in ('m', 'v', 'r'):
                    state[key] = torch.zeros_like(param)

    def s(self, p: torch.Tensor) -> torch.Tensor:
        r"""Numerator function s(x) = p * x^q, using the optimizer's ``p`` and ``q`` hyper-parameters."""
        powered = p.pow(self.q)
        return powered.mul_(self.p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1
            k: int = group['step']

            beta1, beta2, beta3 = group['betas']
            bias_correction1: float = self.debias(beta1, k)
            bias_correction2: float = self.debias(beta2, k)

            # The step size is capped by 1/sqrt(k) regardless of the configured lr.
            step_size: float = min(group['lr'], 1.0 / math.sqrt(k))

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]
                param, grad, exp_avg, exp_avg_sq, numer = self.view_as_real(
                    param, grad, state['m'], state['v'], state['r']
                )

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # Order matters: v is refreshed first, r averages s(v), and m scales the gradient by r.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                numer.mul_(beta3).add_(self.s(exp_avg_sq), alpha=1.0 - beta3)
                exp_avg.mul_(beta1).addcmul_(numer, grad, value=1.0 - beta1)

                de_nom = exp_avg_sq.div(bias_correction2).sqrt_().add_(group['eps'])
                param.add_(exp_avg.div(bias_correction1).div_(de_nom), alpha=-step_size)

        return loss

s(p)

Numerator function s(x) = p * x^q.

Source code in pytorch_optimizer/optimizer/adamg.py
86
87
88
def s(self, p: torch.Tensor) -> torch.Tensor:
    r"""Numerator function s(x) = p * x^q, using the optimizer's ``p`` and ``q`` hyper-parameters."""
    powered = p.pow(self.q)
    return powered.mul_(self.p)

AdamMini

Bases: BaseOptimizer

Use Fewer Learning Rates To Gain More.

Parameters:

Name Type Description Default
model Module

Model instance.

required
model_sharding bool

Set to True if you are using model parallelism with more than 1 GPU, including FSDP and zero_1, zero_2, zero_3 in DeepSpeed. Set to False otherwise.

False
lr float

Learning rate.

1.0
betas Betas

Coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
weight_decay float

Weight decay coefficient, applied in decoupled (AdamW-style) form.

0.1
num_embeds int

Number of embedding dimensions. Could be unspecified if training non-transformer models.

2048
num_heads int

Number of attention heads. Could be unspecified if training non-transformer models.

32
num_query_groups Optional[int]

Number of query groups in Group Query Attention (GQA). If not specified, defaults to num_heads. Could be unspecified for non-transformer models.

None
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adam_mini.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
class AdamMini(BaseOptimizer):  # pragma: no cover
    """Use Fewer Learning Rates To Gain More.

    Args:
        model (nn.Module): Model instance. Parameter groups are built from ``model.named_parameters()``.
        model_sharding (bool): Set to True if you are using model parallelism with more than 1 GPU, including FSDP
            and zero_1, zero_2, zero_3 in DeepSpeed. Set to False otherwise.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and the squared gradient.
        weight_decay (float): Weight decay coefficient, applied in decoupled (AdamW-style) form. Parameters whose
            name contains ``norm`` or ``ln_f`` are not decayed.
        num_embeds (int): Number of embedding dimensions. Could be unspecified if training non-transformer models.
        num_heads (int): Number of attention heads. Could be unspecified if training non-transformer models.
        num_query_groups (Optional[int]): Number of query groups in Group Query Attention (GQA); must divide
            ``num_heads``. If not specified, defaults to num_heads (plain multi-head attention). Could be
            unspecified for non-transformer models.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1.0,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.1,
        model_sharding: bool = False,
        num_embeds: int = 2048,
        num_heads: int = 32,
        num_query_groups: Optional[int] = None,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        # Both are used as divisors when building the attention parameter groups, so zero is rejected.
        self.validate_positive(num_embeds, 'num_embeds')
        self.validate_positive(num_heads, 'num_heads')
        self.validate_non_negative(eps, 'eps')

        # GQA: `num_heads` query heads share `num_query_groups` key/value heads, so the head count must be a
        # multiple of the group count. The default (one group per head) gives q_per_kv == 1, i.e. plain MHA.
        self.num_query_groups: int = num_query_groups if num_query_groups is not None else num_heads
        self.validate_positive(self.num_query_groups, 'num_query_groups')
        self.validate_mod(num_heads, self.num_query_groups)

        # The all_gather in `step_lefts` needs one slot per rank of the process group. That can exceed the local
        # GPU count (multi-node) or differ from it (CPU/gloo), so prefer the real group size when it is available.
        self.world_size: int = (
            dist.get_world_size() if dist.is_available() and dist.is_initialized() else torch.cuda.device_count()
        )

        self.model = model
        self.model_sharding = model_sharding
        self.num_embeds = num_embeds
        self.num_heads = num_heads

        # Substring patterns (matched against parameter names) that select the specialised update rules.
        self.embed_blocks: Set[str] = {'embed', 'embd', 'wte', 'lm_head.weight', 'output.weight'}
        self.qk_blocks: Set[str] = {'k_proj.weight', 'q_proj.weight', 'wq.weight', 'wk.weight'}

        self.maximize = maximize

        groups = self.get_optimizer_groups(weight_decay)

        defaults: Defaults = {'lr': lr, 'betas': betas, 'eps': eps, **kwargs}

        super().__init__(groups, defaults)

    def __str__(self) -> str:
        return 'AdamMini'

    def get_optimizer_groups(self, weight_decay: float):
        """Build one parameter group per trainable parameter, tagged with what its update rule needs."""
        groups = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            group = {
                'name': name,
                'params': param,
                'weight_decay': 0.0 if ('norm' in name or 'ln_f' in name) else weight_decay,
            }

            if any(block in name for block in self.qk_blocks):
                group['parameter_per_head'] = self.num_embeds * self.num_embeds // self.num_heads

            if 'attn.attn.weight' in name or 'attn.qkv.weight' in name:
                group['n_head'] = self.num_heads
                # Number of query heads served by each key/value head.
                group['q_per_kv'] = self.num_heads // self.num_query_groups

            groups.append(group)

        return groups

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

    @staticmethod
    def step_embed(
        p,
        grad,
        state,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        """Apply a plain element-wise Adam update (used for embedding / output-head parameters)."""
        if len(state) == 0:
            state['m'] = torch.zeros_like(p, dtype=torch.float32)
            state['v'] = torch.zeros_like(p, dtype=torch.float32)

        m, v = state['m'], state['v']

        m.lerp_(grad, weight=1.0 - beta1)
        v.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        p.addcdiv_(m, h, value=-lr / bias_correction1)

    @staticmethod
    def step_attn_proj(
        p,
        grad,
        state,
        parameter_per_head: int,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        """Update a q/k projection using one second-moment scalar per block of `parameter_per_head` elements."""
        if len(state) == 0:
            state['m'] = torch.zeros_like(p, dtype=torch.float32).view(-1, parameter_per_head)
            state['head'] = state['m'].shape[0]
            state['v_mean'] = torch.zeros(state['head'], device=state['m'].device)

        m, v = state['m'], state['v_mean']

        head: int = state['head']
        grad = grad.view(head, parameter_per_head)

        m.lerp_(grad, weight=1.0 - beta1)

        tmp_lr = torch.mean(grad * grad, dim=1).to(m.device)
        v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        update = (1 / (h * bias_correction1)).view(head, 1).mul_(m)

        # Restore the parameter's own shape; works for any rank, not only 1-D / 2-D tensors.
        p.add_(update.view_as(p), alpha=-lr)

    @staticmethod
    def step_attn(
        p,
        grad,
        state,
        num_heads: int,
        q_per_kv: int,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        """Update a fused qkv weight viewed as (num_heads, q_per_kv + 2, -1), one second-moment scalar per slice."""
        if len(state) == 0:
            state['m'] = torch.zeros_like(p, dtype=torch.float32).view(num_heads, q_per_kv + 2, -1)
            state['v_mean'] = torch.zeros(num_heads, q_per_kv + 2, device=state['m'].device)

        m, v = state['m'], state['v_mean']

        grad = grad.view(num_heads, q_per_kv + 2, -1)

        m.lerp_(grad, weight=1.0 - beta1)

        tmp_lr = torch.mean(grad * grad, dim=2).to(m.device)
        v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        update = (1 / (h * bias_correction1)).view(num_heads, q_per_kv + 2, -1).mul_(m)

        # Restore the parameter's own shape; works for any rank, not only 1-D / 2-D tensors.
        p.add_(update.view_as(p), alpha=-lr)

    def step_lefts(
        self,
        p,
        grad,
        state,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        """Update any remaining parameter with a single second-moment scalar for the whole tensor."""
        if len(state) == 0:
            dim = torch.tensor(p.numel(), device=p.device, dtype=torch.float32)

            reduced: bool = False
            if self.model_sharding and self.world_size > 1:
                # Gather every rank's local element count. The total is the full parameter size, and the
                # squared-gradient sum is all-reduced only when at least two ranks hold a non-empty piece.
                tensor_list = [torch.zeros_like(dim) for _ in range(self.world_size)]
                dist.all_gather(tensor_list, dim)

                num_holders, dim = 0, 0
                for d in tensor_list:
                    if d > 0:
                        num_holders += 1
                    dim += d

                reduced = num_holders >= 2

            state['m'] = torch.zeros_like(p, dtype=torch.float32)
            state['v_mean'] = torch.tensor(0.0, device=state['m'].device)
            state['dimension'] = dim
            state['reduced'] = reduced

        tmp_lr = torch.sum(grad * grad)

        if state['reduced']:
            dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM)

        tmp_lr.div_(state['dimension'])

        m, v = state['m'], state['v_mean']

        m.lerp_(grad, weight=1.0 - beta1)
        v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        stepsize = (1 / bias_correction1) / h

        update = m * stepsize

        p.add_(update, alpha=-lr)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            name = group['name']

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])
            bias_correction2_sq: float = math.sqrt(bias_correction2)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                # Optimizer math runs in float32 regardless of the parameter dtype.
                grad = grad.to(torch.float32)

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                # Dispatch on the parameter name: embeddings, q/k projections, fused qkv, then everything else.
                if any(block in name for block in self.embed_blocks):
                    self.step_embed(
                        p, grad, state, group['lr'], beta1, beta2, bias_correction1, bias_correction2_sq, group['eps']
                    )
                elif any(block in name for block in self.qk_blocks):
                    self.step_attn_proj(
                        p,
                        grad,
                        state,
                        group['parameter_per_head'],
                        group['lr'],
                        beta1,
                        beta2,
                        bias_correction1,
                        bias_correction2_sq,
                        group['eps'],
                    )
                elif 'attn.attn.weight' in name or 'attn.qkv.weight' in name:
                    self.step_attn(
                        p,
                        grad,
                        state,
                        group['n_head'],
                        group['q_per_kv'],
                        group['lr'],
                        beta1,
                        beta2,
                        bias_correction1,
                        bias_correction2_sq,
                        group['eps'],
                    )
                else:
                    self.step_lefts(
                        p,
                        grad,
                        state,
                        group['lr'],
                        beta1,
                        beta2,
                        bias_correction1,
                        bias_correction2_sq,
                        group['eps'],
                    )

        return loss

AdaMod

Bases: BaseOptimizer

An Adaptive and Momental Bound Method for Stochastic Learning.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and the squared gradient. beta3 is the smoothing coefficient for the per-element adaptive learning rates.

(0.9, 0.99, 0.9999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamod.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class AdaMod(BaseOptimizer):
    """An Adaptive and Momental Bound Method for Stochastic Learning.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): ``(beta1, beta2, beta3)``: smoothing coefficients for the running averages of the gradient,
            the squared gradient and the per-element adaptive learning rate, respectively.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.99, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        group_defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }
        super().__init__(params, group_defaults)

    def __str__(self) -> str:
        return 'AdaMod'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if not state:
                for key in ('exp_avg', 'exp_avg_sq', 'exp_avg_lr'):
                    state[key] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1
            k: int = group['step']

            beta1, beta2, beta3 = group['betas']
            bias_correction1: float = self.debias(beta1, k)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, k))

            # lr * sqrt(bias_correction2); `apply_adam_debias` presumably also divides by bias_correction1
            # unless the group sets `adam_debias` (NOTE(review): confirm against BaseOptimizer).
            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'] * bias_correction2_sq,
                bias_correction1=bias_correction1,
            )

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]
                param, grad, exp_avg, exp_avg_sq, exp_avg_lr = self.view_as_real(
                    param, grad, state['exp_avg'], state['exp_avg_sq'], state['exp_avg_lr']
                )

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                # Per-element Adam learning rate, then bound it by its own long-run (beta3) average.
                element_lr = torch.full_like(de_nom, fill_value=step_size).div_(de_nom)
                exp_avg_lr.mul_(beta3).add_(element_lr, alpha=1.0 - beta3)
                element_lr = torch.min(element_lr, exp_avg_lr)

                param.add_(-element_lr.mul_(exp_avg))

        return loss

AdamP

Bases: BaseOptimizer

Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
delta float

Threshold that determines whether a set of parameters is scale-invariant or not.

0.1
wd_ratio float

Relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters.

0.1
nesterov bool

Enables Nesterov momentum.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamp.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
class AdamP(BaseOptimizer):
    """Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        delta (float): Threshold that determines whether a set of parameters is scale-invariant or not.
        wd_ratio (float): Relative weight decay applied on scale-invariant parameters compared to that applied
            on scale-variant parameters.
        nesterov (bool): Enables Nesterov momentum.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        delta: float = 0.1,
        wd_ratio: float = 0.1,
        nesterov: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(wd_ratio, 'wd_ratio', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'delta': delta,
            'wd_ratio': wd_ratio,
            'nesterov': nesterov,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamP'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Lazily create the per-parameter state (first/second moments, optional AdaNorm buffer)."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure`` (``None`` when no closure is given).

        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Complex tensors are processed through their real views.
                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                if group.get('use_gc'):
                    centralize_gradient(grad, gc_conv_only=False)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # Reciprocal of the (bias-corrected) adaptive denominator.
                inv_de_nom = exp_avg_sq.sqrt().add_(group['eps']).reciprocal_().mul_(bias_correction2_sq)

                perturb = exp_avg.clone()

                if group.get('cautious'):
                    self.apply_cautious(perturb, grad)

                if group['nesterov']:
                    # Nesterov look-ahead: (beta1 * m + (1 - beta1) * g) / denom. The whole numerator must be
                    # normalized by the adaptive denominator, not only the gradient term.
                    perturb.mul_(beta1).add_(grad, alpha=1.0 - beta1).mul_(inv_de_nom)
                else:
                    perturb.mul_(inv_de_nom)

                # Only multi-dimensional weights can be scale-invariant; for those, `projection` removes the radial
                # component of the update and returns the reduced weight-decay ratio.
                wd_ratio: float = 1.0
                if len(p.shape) > 1:
                    perturb, wd_ratio = projection(
                        p,
                        grad,
                        perturb,
                        group['delta'],
                        group['wd_ratio'],
                        group['eps'],
                    )

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=wd_ratio,
                )

                p.add_(perturb, alpha=-step_size)

        return loss

AdamS

Bases: BaseOptimizer

Adam with stable weight decay.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0001
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
ams_bound bool

Whether to use the AMSBound variant of Adam.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adams.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
class AdamS(BaseOptimizer):
    """Adam with stable weight decay.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        ams_bound (bool): Whether to use the AMSBound variant of Adam.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 1e-4,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamS'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Lazily create the per-parameter state; sparse gradients and complex parameters are rejected."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Runs two passes over all parameter groups. The first pass updates the moment estimates and accumulates
        the global mean of the bias-corrected second moments. The second pass applies the weight decay, scaled
        by the inverse of that mean, followed by the Adam update.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure`` (``None`` when no closure is given).

        Raises:
            ZeroParameterSizeError: if no parameter has a gradient.

        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                param_size += p.numel()

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if group['ams_bound']:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    exp_avg_sq_hat = max_exp_avg_sq
                else:
                    exp_avg_sq_hat = exp_avg_sq

                exp_avg_sq_hat_sum += exp_avg_sq_hat.sum() / bias_correction2

        if param_size == 0:
            raise ZeroParameterSizeError

        # Square root of the mean bias-corrected second moment over every parameter; the stable weight decay is
        # divided by this value.
        exp_avg_sq_hat_mean: float = math.sqrt(exp_avg_sq_hat_sum / param_size) + self.defaults['eps']

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                # The gradient was already negated for `maximize` in the first pass; negating it again here
                # would flip it back.
                grad = p.grad

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=1.0 / exp_avg_sq_hat_mean,
                )

                exp_avg_sq_hat = state['max_exp_avg_sq'] if group['ams_bound'] else state['exp_avg_sq']

                # Bias-correct out of place. An in-place `div_` would permanently rescale the stored second-moment
                # state, compounding the correction on every step.
                de_nom = exp_avg_sq_hat.div(bias_correction2).sqrt_().add_(group['eps'])

                p.addcdiv_(state['exp_avg'], de_nom, value=-step_size)

        return loss

AdaMuon

Bases: BaseOptimizer

Adaptive Muon optimizer.

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and scalar or vector parameters should be optimized using AdamW.

Some warnings:
- We believe this optimizer is unlikely to work well for training with small batch size.
- We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

Parameters:

Name Type Description Default
params ParamsT

The parameters to be optimized by Muon.

required
lr float

Learning rate.

0.02
betas tuple

Coefficients used for computing the first- and second-moment running averages.

(0.9, 0.95)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
ns_steps int

The number of Newton-Schulz iterations to run. (5 is probably always enough)

5
ns_coeffs NewtonSchulzWeights

Newton-Schulz coefficients or preset name.

'original'
use_adjusted_lr bool

Whether to use the adjusted learning rate from Moonlight. Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py

False
adamw_lr float

The learning rate for the internal AdamW.

0.0003
adamw_betas tuple

The betas for the internal AdamW.

(0.9, 0.999)
adamw_wd float

The weight decay for the internal AdamW.

0.0
eps float

Term added to the denominator to improve numerical stability.

1e-10
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Example

from pytorch_optimizer import AdaMuon

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

param_groups = [
    dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
    dict(
        params=hidden_gains_biases + non_hidden_params,
        lr=3e-4,
        betas=(0.9, 0.95),
        weight_decay=0.01,
        use_muon=False,
    ),
]

optimizer = AdaMuon(param_groups)

Source code in pytorch_optimizer/optimizer/muon.py
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
class AdaMuon(BaseOptimizer):
    """Adaptive Muon optimizer.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
    each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
    update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

    Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and
    scalar or vector parameters should be optimized using AdamW.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

    Args:
        params (ParamsT): The parameter groups to optimize. Every group must set ``use_muon``.
        lr (float): Learning rate.
        betas (tuple): Coefficients used for computing the first- and second-moment running averages.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        ns_steps (int): The number of Newton-Schulz iterations to run. (5 is probably always enough)
        ns_coeffs (NewtonSchulzWeights): Newton-Schulz coefficients or preset name.
        use_adjusted_lr (bool): Whether to use the adjusted learning rate from Moonlight.
            Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
        adamw_lr (float): The learning rate for the internal AdamW.
        adamw_betas (tuple): The betas for the internal AdamW.
        adamw_wd (float): The weight decay for the internal AdamW.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    Raises:
        ValueError: if a parameter group does not set ``use_muon``.

    Example:
        from pytorch_optimizer import AdaMuon

        hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
        hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
        non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

        param_groups = [
            dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
            dict(
                params=hidden_gains_biases + non_hidden_params,
                lr=3e-4,
                betas=(0.9, 0.95),
                weight_decay=0.01,
                use_muon=False,
            ),
        ]

        optimizer = AdaMuon(param_groups)

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 2e-2,
        betas: Betas = (0.9, 0.95),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        ns_steps: int = 5,
        ns_coeffs: NewtonSchulzWeights = 'original',
        use_adjusted_lr: bool = False,
        adamw_lr: float = 3e-4,
        adamw_betas: Betas = (0.9, 0.999),
        adamw_wd: float = 0.0,
        eps: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(adamw_lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(ns_steps, 'ns_steps')
        self.validate_betas(betas)
        self.validate_betas(adamw_betas)
        self.validate_non_negative(adamw_wd, 'adamw_wd')
        self.validate_non_negative(eps, 'eps')
        ns_coeffs = get_newton_schulz_weights(ns_coeffs)

        self.maximize = maximize

        # Materialize once: the groups are iterated here to fill in defaults and then handed to the base class,
        # so a one-shot iterable (e.g. a generator) would otherwise arrive there already exhausted.
        params = list(params)

        for group in params:
            group = cast(ParamGroup, group)
            if 'use_muon' not in group:
                raise ValueError('`use_muon` must be set.')

            if group['use_muon']:
                group['lr'] = group.get('lr', lr)
                group['betas'] = group.get('betas', betas)
                group['weight_decay'] = group.get('weight_decay', weight_decay)
                group['ns_steps'] = group.get('ns_steps', ns_steps)
                group['ns_coeffs'] = get_newton_schulz_weights(group.get('ns_coeffs', ns_coeffs))
                group['use_adjusted_lr'] = group.get('use_adjusted_lr', use_adjusted_lr)
            else:
                group['lr'] = group.get('lr', adamw_lr)
                group['betas'] = group.get('betas', adamw_betas)
                group['weight_decay'] = group.get('weight_decay', adamw_wd)

            group['weight_decouple'] = group.get('weight_decouple', weight_decouple)
            group['eps'] = group.get('eps', eps)

        super().__init__(params, kwargs)

    def __str__(self) -> str:
        return 'AdaMuon'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Lazily create state: momentum + flattened second moment for Muon groups, Adam moments otherwise."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if group['use_muon']:
                    state['m'] = torch.zeros_like(p)
                    state['v'] = torch.zeros_like(p.flatten())
                else:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure`` (``None`` when no closure is given).

        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                if group['use_muon']:
                    m = state['m']
                    m.lerp_(grad, weight=1.0 - beta1)

                    update = m.clone()

                    # Newton-Schulz operates on matrices: fold every trailing dimension into the columns.
                    if update.ndim > 2:
                        update = update.view(len(update), -1)

                    update = zero_power_via_newton_schulz_5(
                        update, num_steps=group['ns_steps'], weights=group['ns_coeffs']
                    ).flatten()

                    # Element-wise second moment of the orthogonalized update (the "adaptive" part of AdaMuon).
                    v = state['v']
                    v.mul_(beta2).addcmul_(update, update, value=1.0 - beta2)

                    update.div_((v / bias_correction2).sqrt_().add_(group['eps']))
                    update = update.reshape(p.size())

                    # Rescale the update to norm 0.2 * sqrt(numel), i.e. an element-wise RMS of 0.2.
                    update.mul_(0.2 * math.sqrt(p.numel())).div_(update.norm().add_(group['eps']))

                    lr: float = get_adjusted_lr(group['lr'], p.size(), use_adjusted_lr=group['use_adjusted_lr'])

                    p.add_(update, alpha=-lr)
                else:
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                    exp_avg.lerp_(grad, weight=1.0 - beta1)
                    exp_avg_sq.lerp_(grad.square(), weight=1.0 - beta2)

                    de_nom = exp_avg_sq.sqrt().add_(group['eps']).div_(math.sqrt(bias_correction2))

                    p.addcdiv_(exp_avg / bias_correction1, de_nom, value=-group['lr'])

        return loss

AdamWSN

Bases: BaseOptimizer

Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
subset_size int

If you do not know what subset_size to set, a good rule of thumb is to set it as d/2 where d is the hidden dimension of your transformer model. For example, the hidden dimension is 4096 for Llama 7B and so a good subset_size could be 2048. You can leave the subset_size argument to its default value of -1 to use the recommended subset size as stated above.

-1
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Example

sn_params = [module.weight for module in model.modules() if isinstance(module, nn.Linear)]
sn_param_ids = [id(p) for p in sn_params]
regular_params = [p for p in model.parameters() if id(p) not in sn_param_ids]
param_groups = [{'params': regular_params, 'sn': False}, {'params': sn_params, 'sn': True}]
optimizer = AdamWSN(param_groups, lr=args.lr, weight_decay=args.weight_decay, subset_size=args.subset_size)

Source code in pytorch_optimizer/optimizer/snsm.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
class AdamWSN(BaseOptimizer):
    """Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees.

    Parameter groups flagged with ``'sn': True`` use a subset-norm second moment: each gradient is split into
    ``numel // subset_size`` contiguous chunks of ``subset_size`` elements and one second-moment value is kept per
    chunk. All other groups use the usual element-wise AdamW second moment.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of its squared
            (subset-)norm.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        subset_size (int): Chunk size used by the ``'sn'`` groups. A positive value is used as the target size; a
            negative value ``-k`` targets ``sqrt(numel) / k`` instead (so the default of -1 targets ``sqrt(numel)``).
            The actual size is then picked by ``closest_smaller_divisor_of_n_to_k(numel, target)`` so that it
            divides the parameter's number of elements. Must not be 0. As a rule of thumb, d/2 works well, where d
            is the hidden dimension of your transformer model (e.g. 2048 for Llama 7B, whose hidden dim is 4096).
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    Raises:
        ValueError: If ``subset_size`` is 0 (it would otherwise cause a division by zero at the first step).

    Example:
        >>> sn_params = [module.weight for module in model.modules() if isinstance(module, nn.Linear)]
        >>> sn_param_ids = [id(p) for p in sn_params]
        >>> regular_params = [p for p in model.parameters() if id(p) not in sn_param_ids]
        >>> param_groups = [{'params': regular_params, 'sn': False}, {'params': sn_params, 'sn': True}]
        >>> optimizer = AdamWSN(param_groups, lr=args.lr, weight_decay=args.weight_decay, subset_size=args.subset_size)

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        subset_size: int = -1,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        # 0 would be used as a divisor when deriving the default chunk size (sqrt(numel) / |subset_size|).
        if subset_size == 0:
            raise ValueError('subset_size must be non-zero (positive size, or negative for sqrt(numel) / |k|)')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'subset_size': subset_size,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamWSN'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily allocate per-parameter state for ``group``."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(grad)

                if group.get('sn'):
                    size: int = grad.numel()

                    if 'subset_size' not in state:
                        state['subset_size'] = closest_smaller_divisor_of_n_to_k(
                            size,
                            (
                                group['subset_size']
                                if group['subset_size'] > 0
                                else int(math.sqrt(size) / abs(int(group['subset_size'])))
                            ),
                        )

                    # One second-moment entry per chunk of ``subset_size`` elements, kept as a column so it
                    # broadcasts over the chunk in ``step``.
                    state['exp_avg_sq'] = torch.zeros(
                        (size // state['subset_size'], 1), dtype=grad.dtype, device=grad.device
                    )
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            use_sn: bool = bool(group.get('sn'))

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                size = grad.numel()

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                if use_sn:
                    # ``reshape`` rather than ``view``: ``view`` raises for non-contiguous gradients/buffers
                    # (e.g. transposed parameters). The reshaped tensors are only read, so a copy is harmless.
                    chunk_shape = (size // state['subset_size'], state['subset_size'])
                    second_moment_update = grad.reshape(chunk_shape).pow(2).sum(dim=1, keepdim=True)
                else:
                    second_moment_update = grad.pow(2)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).add_(second_moment_update, alpha=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if use_sn:
                    # de_nom is (num_chunks, 1) and broadcasts across each chunk of the momentum.
                    norm_grad = (exp_avg.reshape(chunk_shape) / de_nom).reshape(p.shape)
                    p.add_(norm_grad, alpha=-step_size)
                else:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

        return loss

Adan

Bases: BaseOptimizer

Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

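Coefficients (beta1, beta2, beta3) used for computing the running averages of the gradient, of the gradient difference, and of the squared (gradient + beta2 * gradient difference) term, respectively.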

(0.98, 0.92, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Decoupled weight decay.

False
max_grad_norm float

Maximum gradient norm to clip.

0.0
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adan.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
class Adan(BaseOptimizer):
    """Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): ``(beta1, beta2, beta3)``: decay of the running average of the gradient, of the gradient
            difference, and of the squared ``grad + beta2 * grad_diff`` term, respectively.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Decoupled weight decay. When False, parameters are instead divided by
            ``1 + lr * weight_decay`` after the update.
        max_grad_norm (float): Maximum gradient norm to clip. 0.0 disables clipping.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.98, 0.92, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        max_grad_norm: float = 0.0,
        foreach: Optional[bool] = None,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(max_grad_norm, 'max_grad_norm')
        self.validate_non_negative(eps, 'eps')

        self.max_grad_norm = max_grad_norm
        self.maximize = maximize
        self.foreach = foreach

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'max_grad_norm': max_grad_norm,
            'foreach': foreach,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adan'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily allocate per-parameter state for ``group``.

        Accepts ``clip_global_grad_norm`` (float or tensor) in ``kwargs``; it is used to seed ``previous_grad``.
        """
        if 'step' not in group:
            group['step'] = 0

        clip_global_grad_norm: float = kwargs.get('clip_global_grad_norm', 0.0)

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_diff'] = torch.zeros_like(p)

                # ``previous_grad`` holds the *negated* clipped gradient of the last step, so that
                # ``previous_grad + grad`` is the gradient difference. Seed it with the negation of the gradient
                # this first step will actually use, making the first difference zero. ``step`` negates ``grad``
                # after this point when maximizing, so the seed's sign must follow ``maximize``.
                previous_grad = grad.clone().mul_(clip_global_grad_norm)
                if not self.maximize:
                    previous_grad.neg_()
                state['previous_grad'] = previous_grad

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
        """Return the gradient scaling factor ``min(1, max_grad_norm / (norm + eps))``, or 1.0 when disabled."""
        if self.defaults['max_grad_norm'] == 0.0:
            return 1.0

        global_grad_norm = get_global_gradient_norm(self.param_groups)
        global_grad_norm.sqrt_().add_(self.defaults['eps'])

        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        # Gradient centralization and AdaNorm are only implemented in the per-parameter path.
        if group.get('foreach') is False:
            return False

        if group.get('use_gc') or group.get('adanorm'):
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        exp_avg_sqs: List[torch.Tensor],
        exp_avg_diffs: List[torch.Tensor],
        prev_grads: List[torch.Tensor],
        clip_global_grad_norm: Union[torch.Tensor, float],
    ) -> None:
        beta1, beta2, beta3 = group['betas']
        lr = group['lr']

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2: float = self.debias(beta2, group['step'])
        bias_correction3_sq: float = math.sqrt(self.debias(beta3, group['step']))

        if self.maximize:
            torch._foreach_neg_(grads)

        if isinstance(clip_global_grad_norm, torch.Tensor):
            clip_global_grad_norm = clip_global_grad_norm.item()

        torch._foreach_mul_(grads, clip_global_grad_norm)

        # prev_grads store -g_{t-1}, so this is g_t - g_{t-1}.
        grad_diffs = torch._foreach_add(prev_grads, grads)

        torch._foreach_mul_(exp_avgs, beta1)
        torch._foreach_add_(exp_avgs, grads, alpha=1.0 - beta1)

        torch._foreach_mul_(exp_avg_diffs, beta2)
        torch._foreach_add_(exp_avg_diffs, grad_diffs, alpha=1.0 - beta2)

        torch._foreach_mul_(grad_diffs, beta2)
        torch._foreach_add_(grad_diffs, grads)

        torch._foreach_mul_(exp_avg_sqs, beta3)
        torch._foreach_addcmul_(exp_avg_sqs, grad_diffs, grad_diffs, value=1.0 - beta3)

        de_noms = torch._foreach_sqrt(exp_avg_sqs)
        torch._foreach_div_(de_noms, bias_correction3_sq)
        torch._foreach_add_(de_noms, group['eps'])

        if group['weight_decouple']:
            torch._foreach_mul_(params, 1.0 - lr * group['weight_decay'])

        torch._foreach_addcdiv_(params, exp_avgs, de_noms, value=-lr / bias_correction1)
        torch._foreach_addcdiv_(params, exp_avg_diffs, de_noms, value=-lr * beta2 / bias_correction2)

        if not group['weight_decouple']:
            torch._foreach_div_(params, 1.0 + lr * group['weight_decay'])

        torch._foreach_copy_(prev_grads, torch._foreach_neg(grads))

    def _step_per_param(self, group: ParamGroup, clip_global_grad_norm: Union[torch.Tensor, float]) -> None:
        beta1, beta2, beta3 = group['betas']

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2: float = self.debias(beta2, group['step'])
        bias_correction3_sq: float = math.sqrt(self.debias(beta3, group['step']))

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']
            grad_diff = state['previous_grad']

            p, grad, exp_avg, exp_avg_sq, exp_avg_diff, grad_diff = self.view_as_real(
                p, grad, exp_avg, exp_avg_sq, exp_avg_diff, grad_diff
            )

            grad.mul_(clip_global_grad_norm)

            if group.get('use_gc'):
                centralize_gradient(grad, gc_conv_only=False)

            # previous_grad stores -g_{t-1}, so this turns it into g_t - g_{t-1}.
            grad_diff.add_(grad)

            s_grad = self.get_adanorm_gradient(
                grad=grad,
                adanorm=group.get('adanorm', False),
                exp_grad_norm=state.get('exp_grad_adanorm', None),
                r=group.get('adanorm_r', None),
            )

            exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
            exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)

            grad_diff.mul_(beta2).add_(grad)
            exp_avg_sq.mul_(beta3).addcmul_(grad_diff, grad_diff, value=1.0 - beta3)

            de_nom = exp_avg_sq.sqrt().div_(bias_correction3_sq).add_(group['eps'])

            if group['weight_decouple']:
                p.mul_(1.0 - group['lr'] * group['weight_decay'])

            p.addcdiv_(exp_avg, de_nom, value=-group['lr'] / bias_correction1)
            p.addcdiv_(exp_avg_diff, de_nom, value=-group['lr'] * beta2 / bias_correction2)

            if not group['weight_decouple']:
                p.div_(1.0 + group['lr'] * group['weight_decay'])

            # Store -g_t for the next step out of place: negating ``grad`` in place would leave the user's
            # ``p.grad`` sign-flipped after ``step()`` (the foreach path does not do that either).
            neg_grad = grad.neg()
            state['previous_grad'].copy_(
                torch.view_as_complex(neg_grad) if torch.is_complex(state['previous_grad']) else neg_grad
            )

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        clip_global_grad_norm = self.get_global_gradient_norm()

        for group in self.param_groups:
            self.init_group(group, clip_global_grad_norm=clip_global_grad_norm)
            group['step'] += 1

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['exp_avg', 'exp_avg_sq', 'exp_avg_diff', 'previous_grad']
                )
                if params:
                    self._step_foreach(
                        group,
                        params,
                        grads,
                        state_dict['exp_avg'],
                        state_dict['exp_avg_sq'],
                        state_dict['exp_avg_diff'],
                        state_dict['previous_grad'],
                        clip_global_grad_norm,
                    )
            else:
                self._step_per_param(group, clip_global_grad_norm)

        return loss

AdaNorm

Bases: BaseOptimizer

Adaptive Gradient Norm Correction based Optimizer for CNNs.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing the running averages of the norm-corrected gradient and of the squared gradient.

(0.9, 0.99)
r float

EMA factor. Preferred values are between 0.9 and 0.99.

0.95
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
ams_bound bool

Whether to use the AMSBound variant.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adanorm.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class AdaNorm(BaseOptimizer):
    """Adaptive Gradient Norm Correction based Optimizer for CNNs.

    Adam-style optimizer whose first moment accumulates the gradient returned by ``get_adanorm_gradient`` (the
    AdaNorm norm-corrected gradient, driven by the per-parameter ``exp_grad_norm`` state), while the second moment
    is built from the raw gradient.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the norm-corrected gradient and of the
            squared gradient.
        r (float): EMA factor passed to the gradient-norm correction. Preferred values are between 0.9 and 0.99.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    Raises:
        NoSparseGradientError: From ``step`` if a parameter has a sparse gradient.
        NoComplexParameterError: From ``step`` if a parameter is complex.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.99),
        r: float = 0.95,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'r': r,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaNorm'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily allocate per-parameter state for ``group``."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_var'] = torch.zeros_like(p)
                # single-element running gradient-norm state consumed by get_adanorm_gradient
                state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

                if group['ams_bound']:
                    state['max_exp_avg_var'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=True,
                    exp_grad_norm=state['exp_grad_norm'],
                    r=group['r'],
                )

                # First moment tracks the norm-corrected gradient; second moment tracks the raw gradient.
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_var.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_var,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss

AdaPNM

Bases: BaseOptimizer

Adam + Positive-Negative Momentum Optimizers.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients (beta1, beta2, beta3): beta1 ** 2 is the decay of the momentum buffers, beta2 is the decay of the squared-gradient average, and beta3 is the positive-negative momentum weight.

(0.9, 0.999, 1.0)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Use decoupled weight decay.

True
fixed_decay bool

Fix weight decay.

False
ams_bound bool

Whether to use the AMSBound variant.

True
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adapnm.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class AdaPNM(BaseOptimizer):
    """Adam + Positive-Negative Momentum Optimizers.

    Two first-moment buffers are kept per parameter and swap roles every step: on odd steps ``exp_avg`` receives
    the new gradient, on even steps ``neg_exp_avg`` does. The update direction is
    ``((1 + beta3) * positive - beta3 * negative) / sqrt((1 + beta3)^2 + beta3^2)``, scaled by the
    Adam-style denominator.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): ``(beta1, beta2, beta3)``: ``beta1 ** 2`` is the decay of the momentum buffers, ``beta2`` the
            decay of the squared-gradient average and ``beta3`` the positive-negative momentum weight.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Use decoupled weight decay.
        fixed_decay (bool): Fix weight decay.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            ams_bound=ams_bound,
            eps=eps,
        )
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaPNM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Reject sparse gradients and create the momentum/second-moment buffers on first sight of a parameter."""
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if len(state) > 0:
                continue

            for key in ('exp_avg', 'exp_avg_sq', 'neg_exp_avg'):
                state[key] = torch.zeros_like(param)

            if group['ams_bound']:
                state['max_exp_avg_sq'] = torch.zeros_like(param)

            if group.get('adanorm'):
                state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1
            step_count: int = group['step']

            beta1, beta2, beta3 = group['betas']

            # PNM decays its momentum buffers with beta1 squared and renormalizes the PN combination.
            momentum_decay: float = beta1 ** 2  # fmt: skip
            noise_norm: float = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)  # fmt: skip

            bias_correction1: float = self.debias(beta1, step_count)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, step_count))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            # The two momentum buffers alternate roles between odd and even steps.
            if step_count % 2 == 1:
                positive_key, negative_key = 'exp_avg', 'neg_exp_avg'
            else:
                positive_key, negative_key = 'neg_exp_avg', 'exp_avg'

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p, grad, positive_momentum, negative_momentum, exp_avg_sq = self.view_as_real(
                    p, grad, state[positive_key], state[negative_key], state['exp_avg_sq']
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                positive_momentum.mul_(momentum_decay).add_(s_grad, alpha=1.0 - momentum_decay)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                direction = (
                    positive_momentum.mul(1.0 + beta3).add_(negative_momentum, alpha=-beta3).mul_(1.0 / noise_norm)
                )

                p.addcdiv_(direction, de_nom, value=-step_size)

        return loss

AdaShift

Bases: BaseOptimizer

Decorrelation and Convergence of Adaptive Learning Rate Methods.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
keep_num int

Number of gradients used to compute first moment estimation.

10
reduce_func Optional[Callable]

Function applied to squared gradients to reduce correlation. If None, no function is applied.

max
eps float

Term added to the denominator to improve numerical stability.

1e-10
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adashift.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class AdaShift(BaseOptimizer):
    """Decorrelation and Convergence of Adaptive Learning Rate Methods.

    A FIFO window of the most recent ``keep_num`` gradients is kept per parameter. Once the window is full, the first
    moment is a ``beta1``-weighted average over the window, and the second moment is fed only with the gradient that
    is shifted out of the window, which decorrelates it from the current gradient. No update is applied while a
    parameter's window is still filling.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): ``beta1`` weights the gradients inside the window for the first moment, ``beta2`` is the decay
            rate of the moving average of the reduced squared (shifted) gradient.
        keep_num (int): Number of gradients used to compute first moment estimation.
        reduce_func (Optional[Callable]): Function applied to squared gradients to reduce correlation.
            If None, no function is applied.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        keep_num: int = 10,
        reduce_func: Optional[Callable] = torch.max,
        eps: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_positive(keep_num, 'keep_num')
        self.validate_non_negative(eps, 'eps')

        self.reduce_func: Callable = reduce_func if reduce_func is not None else lambda x: x
        self.maximize = maximize

        defaults: Defaults = {'lr': lr, 'betas': betas, 'keep_num': keep_num, 'eps': eps, **kwargs}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaShift'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # Start with an empty window: every gradient is enqueued exactly once, in `step`, after `maximize`
                # has been applied to it. Seeding it here as well would count the first gradient twice.
                state['grad_queue'] = deque(maxlen=group['keep_num'])
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            keep_num: int = group['keep_num']

            # Normalized weights of the oldest and the newest gradient in the window average.
            exp_weight_sum: float = sum(beta1 ** i for i in range(keep_num))  # fmt: skip
            first_grad_weight: float = beta1 ** (keep_num - 1) / exp_weight_sum
            last_grad_weight: float = 1.0 / exp_weight_sum

            # A window can only be full before the append from step `keep_num + 1` on. By then the second moment has
            # been updated `step - keep_num` (>= 1) times if the parameter received a gradient at every step. Earlier
            # steps never use the correction. A non-positive exponent would give a zero or negative correction (and
            # a NaN square root), or raise for beta2 == 0, so it is not computed there.
            bias_correction: float = (
                self.debias(beta2, group['step'] - keep_num) if group['step'] > keep_num else 1.0
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_queue = state['grad_queue']

                # When the window is full, its oldest entry is evicted by the append below. That entry is the
                # shifted gradient g_{t - keep_num}.
                offset_grad = grad_queue[0] if len(grad_queue) == keep_num else None

                grad_queue.append(grad.clone())

                if offset_grad is None:
                    # Window still filling up: no update yet.
                    continue

                exp_avg = state['exp_avg']
                # Slide the window: drop the evicted gradient's weight, age the remaining ones, add the newest.
                exp_avg.sub_(offset_grad, alpha=first_grad_weight).mul_(beta1).add_(grad, alpha=last_grad_weight)

                # `offset_grad` is no longer referenced by the queue, so squaring it in place is safe.
                reduced_grad_sq = self.reduce_func(offset_grad.pow_(2))

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)

                update = exp_avg.clone()
                if group.get('cautious'):
                    self.apply_cautious(update, grad)

                update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))

                p.add_(update, alpha=-group['lr'])

        return loss

AdaSmooth

Bases: BaseOptimizer

An Adaptive Learning Rate Method based on Effective Ratio.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

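Coefficients (beta1, beta2) bounding the adaptive smoothing coefficient, which the effective ratio interpolates between 1 - beta2 and 1 - beta1; its square weights the newest squared gradient in the moving average.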

(0.5, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adasmooth.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class AdaSmooth(BaseOptimizer):
    """An Adaptive Learning Rate Method based on Effective Ratio.

    The effective ratio ``e = |sum of parameter displacements| / sum of |parameter displacements|`` (in [0, 1])
    selects a smoothing coefficient ``c = (beta2 - beta1) * e + (1 - beta2)``, i.e. between ``1 - beta2`` and
    ``1 - beta1``. ``c ** 2`` is the weight of the newest squared gradient in the second-moment moving average.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Pair bounding the smoothing coefficient, which ranges from ``1 - beta2`` to ``1 - beta1``.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.5, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaSmooth'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # NOTE: prev_param starts at zero, so the first displacement equals the parameter's initial value.
                state['prev_param'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)  # running sum of displacements
                state['n'] = torch.zeros_like(p)  # running sum of absolute displacements
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                s, n, prev_param, exp_avg_sq = state['s'], state['n'], state['prev_param'], state['exp_avg_sq']

                p, grad, s, n, prev_param, exp_avg_sq = self.view_as_real(p, grad, s, n, prev_param, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p_diff = p - prev_param

                s.add_(p_diff)
                n.add_(p_diff.abs())

                # Effective ratio e_t. When the parameter has not moved at all (e.g. a zero-initialized parameter on
                # the first step) both sums are 0 and the plain ratio would be 0 / 0 = NaN, poisoning exp_avg_sq and
                # the parameter. Treat "no movement" as e_t = 0. The result is unchanged whenever n.sum() > 0.
                s_sum, n_sum = s.sum(), n.sum()
                c = torch.where(n_sum > 0.0, s_sum.abs() / n_sum, torch.zeros_like(n_sum))
                c.mul_(beta2 - beta1).add_(1.0 - beta2)

                c_p2 = c.pow(2)

                exp_avg_sq.mul_(1.0 - c_p2).addcmul_(grad, grad, value=c_p2)

                step_size = torch.full_like(exp_avg_sq, fill_value=group['lr'])
                step_size.div_((exp_avg_sq + group['eps']).sqrt()).mul_(grad)

                p.add_(-step_size)

                state['prev_param'].copy_(torch.view_as_complex(p) if torch.is_complex(state['prev_param']) else p)

        return loss

AdaTAM

Bases: BaseOptimizer

Adaptive Torque-Aware Momentum.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups. :param lr: float. learning rate. :param betas: BETAS. coefficients for the torque-scaled momentum and the running average of the squared gradient. :param decay_rate: float. smoothing decay rate. :param weight_decay: float. weight decay (L2 penalty). :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW. :param fixed_decay: bool. fix weight decay. :param eps: float. term added to the denominator to improve numerical stability. :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.

Source code in pytorch_optimizer/optimizer/tam.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
class AdaTAM(BaseOptimizer):
    r"""Adaptive Torque-Aware Momentum.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): ``beta1`` is the decay of the torque-scaled momentum, ``beta2`` is the decay rate of the moving
            average of the squared gradient.
        decay_rate (float): Smoothing decay rate of the momentum/gradient alignment estimate.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.
        kwargs: Extra options, stored in the parameter-group defaults.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        decay_rate: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(decay_rate, 'decay_rate', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'decay_rate': decay_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            # Previously accepted but silently discarded; keep them like the other optimizers do.
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaTAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['s'] = torch.zeros_like(grad)  # smoothed momentum/gradient alignment
                state['exp_avg'] = torch.zeros_like(grad)
                state['exp_avg_sq'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            decay_rate: float = group['decay_rate']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s, exp_avg, exp_avg_sq = state['s'], state['exp_avg'], state['exp_avg_sq']

                # Element-wise product of the momentum and the gradient, each L2-normalized along dim 0. Every entry
                # lies in [-1, 1] and measures how well the new gradient agrees with the current momentum.
                corr = normalize(exp_avg, p=2.0, dim=0).mul_(normalize(grad, p=2.0, dim=0))
                s.mul_(decay_rate).add_(corr, alpha=1.0 - decay_rate)

                # Map the alignment to a damping factor in [0, 1] (+ eps) that scales the incoming gradient.
                d = ((1.0 + s) / 2.0).add_(group['eps']).mul_(grad)

                exp_avg.mul_(beta1).add_(d)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-group['lr'])

        return loss

AdEMAMix

Bases: BaseOptimizer

Better, Faster, Older.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients (beta1, beta2, beta3) for the fast moving average of the gradient, the moving average of the squared gradient, and the slow moving average of the gradient.

(0.9, 0.999, 0.9999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
alpha float

Weight of the slow moving average in the update; values between 4 and 10 usually work well.

5.0
t_alpha_beta3 Optional[float]

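Number of steps over which alpha and beta3 are warmed up to their target values (typically the total number of training iterations). If None, no warmup is applied.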

None
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ademamix.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class AdEMAMix(BaseOptimizer):
    """Better, Faster, Older.

    Mixes a fast and a slow exponential moving average of the gradient. The update is
    ``p -= lr * (m_fast / (1 - beta1 ** t) + alpha_t * m_slow) / (sqrt(v / (1 - beta2 ** t)) + eps)``, where only
    the fast EMA is bias-corrected.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Decays of the fast gradient EMA, the squared-gradient EMA and the slow gradient EMA.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        alpha (float): Weight of the slow EMA in the update. Usually between 4 and 10 would work well.
        t_alpha_beta3 (Optional[float]): Number of steps over which alpha and beta3 are warmed up to their target
            values (typically the total number of iterations). If None, no warmup is applied.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        alpha: float = 5.0,
        t_alpha_beta3: Optional[float] = None,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(alpha, 'alpha')
        self.validate_non_negative(t_alpha_beta3, 't_alpha_beta3')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'alpha': alpha,
            't_alpha_beta3': t_alpha_beta3,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdEMAMix'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_slow'] = torch.zeros_like(p)

    @staticmethod
    def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float:
        # Linear warmup of alpha over `t_alpha_beta3` steps; constant when no schedule is configured.
        return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha)

    @staticmethod
    def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float:
        if t_alpha_beta3 is None:
            return beta3

        log_beta1, log_beta3 = math.log(beta1), math.log(beta3)

        return min(
            math.exp(
                log_beta1 * log_beta3 / ((1.0 - step / t_alpha_beta3) * log_beta3 + (step / t_alpha_beta3) * log_beta1)
            ),
            beta3,
        )

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2, beta3 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha'])
            beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1.0 - beta3_t)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                update = exp_avg.clone()
                if group.get('cautious'):
                    self.apply_cautious(update, grad)

                # Bias-correct only the fast EMA, then mix in the (uncorrected) slow EMA. Folding 1 / bias_correction1
                # into the step size would also amplify the slow EMA during the first steps.
                update.div_(bias_correction1).add_(exp_avg_slow, alpha=alpha_t).div_(de_nom)

                # Per-parameter step size. Dividing a group-level value in place would compound the stable-AdamW
                # scaling across every parameter that follows in the group.
                step_size: float = group['lr']
                if group.get('stable_adamw'):
                    step_size /= self.get_stable_adamw_rms(grad, exp_avg_sq)

                p.add_(update, alpha=-step_size)

        return loss

ADOPT

Bases: BaseOptimizer

Modified Adam Can Converge with Any β2 with the Optimal Rate.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients (beta1, beta2): beta1 is the momentum decay of the normalized gradient, and beta2 is the decay rate of the moving average of the squared gradient.

(0.9, 0.9999)
clip_lambda Callable[[float], float]

Function to clip gradient. Default is step ** 0.25.

lambda step: pow(step, 0.25)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adopt.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class ADOPT(BaseOptimizer):
    """Modified Adam Can Converge with Any β2 with the Optimal Rate.

    At step t the current gradient is normalized by the square root of the second-moment estimate from the previous
    step, optionally clipped, folded into the momentum, and the momentum is applied to the parameters. Only after
    that is the second moment updated with the current gradient. The very first step only seeds the second moment
    and does not move the parameters.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): ``beta1`` is the momentum decay of the normalized gradient, ``beta2`` is the decay rate of the
            moving average of the squared gradient.
        clip_lambda (Callable[[float], float]): Maps the step number to the bound ``clip`` used to clamp the
            normalized gradient to ``[-clip, clip]``. Default is `step ** 0.25`. If None, no clipping is applied.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.9999),
        clip_lambda: Callable[[float], float] = lambda step: math.pow(step, 0.25),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        foreach: Optional[bool] = None,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.clip_lambda = clip_lambda
        self.maximize = maximize
        self.foreach = foreach

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'foreach': foreach,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ADOPT'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the group step counter and lazily create per-parameter state for parameters with gradients."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        """Return whether the multi-tensor path may be used for this group."""
        # An explicit opt-out always wins.
        if group.get('foreach') is False:
            return False

        # The cautious mask and stable-AdamW scaling are only implemented in the per-parameter path.
        if group.get('cautious') or group.get('stable_adamw'):
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        exp_avg_sqs: List[torch.Tensor],
    ) -> None:
        """Multi-tensor ADOPT update; must stay numerically equivalent to `_step_per_param` for real tensors.

        NOTE(review): unlike the per-parameter path, no `view_as_real` is applied here, so this path presumably
        relies on `can_use_foreach` excluding complex parameters - confirm.
        """
        beta1, beta2 = group['betas']
        lr = group['lr']
        eps = group['eps']

        if self.maximize:
            torch._foreach_neg_(grads)

        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=lr,
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        # First step: only seed the second moment with g^2; the parameters are not moved.
        if group['step'] == 1:
            torch._foreach_addcmul_(exp_avg_sqs, grads, grads)
            return

        # Normalize by the *previous* second moment (it is updated at the end of this step).
        de_noms = torch._foreach_sqrt(exp_avg_sqs)
        torch._foreach_clamp_min_(de_noms, eps)

        normed_grads = torch._foreach_div(grads, de_noms)
        if self.clip_lambda is not None:
            clip: float = self.clip_lambda(group['step'])
            torch._foreach_clamp_min_(normed_grads, -clip)
            torch._foreach_clamp_max_(normed_grads, clip)

        # exp_avg <- beta1 * exp_avg + (1 - beta1) * normed_grad
        torch._foreach_lerp_(exp_avgs, normed_grads, weight=1.0 - beta1)

        torch._foreach_add_(params, exp_avgs, alpha=-lr)

        torch._foreach_mul_(exp_avg_sqs, beta2)
        torch._foreach_addcmul_(exp_avg_sqs, grads, grads, value=1.0 - beta2)

    def _step_per_param(self, group: ParamGroup) -> None:
        """Per-parameter ADOPT update; also supports the cautious mask and stable-AdamW scaling."""
        beta1, beta2 = group['betas']
        lr: float = group['lr']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

            self.apply_weight_decay(
                p=p,
                grad=grad,
                lr=lr,
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=group['fixed_decay'],
            )

            # First step: only seed the second moment; the parameter is not moved.
            if group['step'] == 1:
                exp_avg_sq.addcmul_(grad, grad.conj())
                continue

            # Normalize by the *previous* second moment (it is updated at the end of this iteration).
            de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])

            normed_grad = grad.div(de_nom)
            if self.clip_lambda is not None:
                clip = self.clip_lambda(group['step'])
                normed_grad.clamp_(-clip, clip)

            # exp_avg <- beta1 * exp_avg + (1 - beta1) * normed_grad
            exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

            # Clone only when masking, so the stored momentum is never modified by the cautious mask.
            if group.get('cautious'):
                update = exp_avg.clone()
                self.apply_cautious(update, normed_grad)
            else:
                update = exp_avg

            step_lr = lr
            if group.get('stable_adamw'):
                step_lr /= self.get_stable_adamw_rms(grad, exp_avg_sq)

            p.add_(update, alpha=-step_lr)

            exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform one optimization step, choosing the foreach or per-parameter path for each group."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['exp_avg', 'exp_avg_sq']
                )
                if params:
                    self._step_foreach(group, params, grads, state_dict['exp_avg'], state_dict['exp_avg_sq'])
            else:
                self._step_per_param(group)

        return loss

agc

agc(p, grad, agc_eps=0.001, agc_clip_val=0.01, eps=1e-06)

Clip gradient values in excess of the unit-wise norm.

Parameters:

Name Type Description Default
p Tensor

Parameter tensor.

required
grad Tensor

Gradient tensor.

required
agc_eps float

AGC epsilon to clip the norm of the parameter.

0.001
agc_clip_val float

Norm clip value.

0.01
eps float

Small term to prevent division by zero, unrelated to standard optimizer eps.

1e-06
Source code in pytorch_optimizer/optimizer/agc.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def agc(
    p: torch.Tensor, grad: torch.Tensor, agc_eps: float = 1e-3, agc_clip_val: float = 1e-2, eps: float = 1e-6
) -> torch.Tensor:
    """Clip gradient values in excess of the unit-wise norm.

    The per-unit limit is ``agc_clip_val`` times the unit-wise norm of ``p`` (floored at ``agc_eps``). Wherever the
    unit-wise norm of ``grad`` exceeds that limit, the gradient is rescaled down to it; elsewhere ``grad`` is kept.

    Args:
        p (torch.Tensor): Parameter tensor.
        grad (torch.Tensor): Gradient tensor.
        agc_eps (float): AGC epsilon to clip the norm of the parameter.
        agc_clip_val (float): Norm clip value.
        eps (float): Small term to prevent division by zero, unrelated to standard optimizer eps.

    Returns:
        torch.Tensor: the clipped gradient.
    """
    # Largest gradient norm allowed per unit, proportional to the floored parameter norm.
    max_norm = unit_norm(p).clamp_min_(agc_eps)
    max_norm.mul_(agc_clip_val)

    # Floor the gradient norm so the rescaling factor never divides by zero.
    grad_norm = unit_norm(grad).clamp_min_(eps)

    scale = max_norm / grad_norm
    needs_clip = grad_norm > max_norm

    return torch.where(needs_clip, grad * scale, grad)

AggMo

Bases: BaseOptimizer

Aggregated Momentum: Stability Through Passive Damping.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Damping coefficients, one per momentum buffer; each step moves the parameters by the learning rate times the average of all buffers.

(0.0, 0.9, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/aggmo.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class AggMo(BaseOptimizer):
    """Aggregated Momentum: Stability Through Passive Damping.

    One momentum buffer is kept per damping coefficient in ``betas``. Every step, each buffer is decayed by its
    coefficient and accumulates the gradient, and the parameters move by the learning rate times the average of the
    buffers.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Damping coefficients, one per momentum buffer.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.0, 0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        super().__init__(
            params,
            {
                'lr': lr,
                'betas': betas,
                'weight_decay': weight_decay,
                'weight_decouple': weight_decouple,
                'fixed_decay': fixed_decay,
            },
        )

    def __str__(self) -> str:
        return 'AggMo'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Create the step counter and the per-beta momentum buffers lazily."""
        group.setdefault('step', 0)

        betas = group['betas']
        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if not param_state:
                # Buffers are keyed by the beta value itself.
                param_state['momentum_buffer'] = {beta: torch.zeros_like(param) for beta in betas}

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform one AggMo update; returns the closure's loss (or ``None``)."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            lr, betas = group['lr'], group['betas']
            # Each buffer contributes an equal share of the learning rate.
            buffer_lr = -lr / len(betas)

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                buffers = self.state[param]['momentum_buffer']
                for beta in betas:
                    velocity = buffers[beta].mul_(beta).add_(grad)
                    param.add_(velocity, alpha=buffer_lr)

        return loss

Aida

Bases: BaseOptimizer

A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and of the squared projected residual between the gradient and its running average.

(0.9, 0.999)
k int

Number of vectors projected per iteration.

2
xi float

Term used in vector projections to avoid division by zero.

1e-20
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
rectify bool

Perform the rectified update similar to RAdam.

False
n_sma_threshold int

Threshold on the SMA length (n_sma) at or above which the rectified adaptive update is applied (recommended is 5).

5
degenerated_to_sgd bool

Perform SGD update when variance of gradient is high.

True
ams_bound bool

Whether to use the AMSBound variant.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/aida.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
class Aida(BaseOptimizer):
    """A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range.

    The first moment ``exp_avg`` is a running average of the gradient. Before the second-moment buffer
    ``exp_avg_var`` is updated, the gradient and ``exp_avg`` are projected onto each other ``k`` times. The buffer
    then accumulates the square of the difference between the two projected vectors.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient (``exp_avg``) and of the
            squared projected residual (``exp_avg_var``).
        k (int): Number of vectors projected per iteration.
        xi (float): Term used in vector projections to avoid division by zero.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        rectify (bool): Perform the rectified update similar to RAdam.
        n_sma_threshold (int): Threshold on the SMA length at or above which the rectified adaptive update is used
            (recommended is 5).
        degenerated_to_sgd (bool): Perform SGD update when variance of gradient is high.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.
        **kwargs: extra entries merged into the group defaults. ``adanorm``, ``adanorm_r`` and ``adam_debias`` are
            read from the group when present.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        k: int = 2,
        xi: float = 1e-20,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(k, 'k')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(xi, 'xi')
        self.validate_non_negative(eps, 'eps')

        # These settings are optimizer-wide (instance attributes), not per-group.
        self.k = k
        self.xi = xi
        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Aida'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily allocate per-parameter state.

        Raises:
            NoSparseGradientError: if a gradient is sparse.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_var'] = torch.zeros_like(p)

                # Running maximum of the second moment, only needed for the AMSBound variant.
                if group['ams_bound']:
                    state['max_exp_avg_var'] = torch.zeros_like(p)

                # Single-element running gradient-norm statistic used by the optional AdaNorm variant.
                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform one Aida update; returns the closure's loss (or ``None``)."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # Step size and SMA length from the RAdam-style rectification helper.
            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            # Optionally adjusts step_size by bias_correction1, depending on the group's ``adam_debias`` flag.
            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # Return value is unused, so this presumably flips the gradient in place when maximizing.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # Gradient fed into the first moment (AdaNorm-adjusted when ``adanorm`` is enabled).
                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

                # Work on copies so the projections below do not touch the gradient or the stored moment.
                proj_g = grad.detach().clone()
                proj_m = exp_avg.detach().clone()

                # k rounds of mutual projection: each vector is rescaled by <g, m> / (||self||^2 + xi).
                # Both scalars are computed from the same inner product before either vector is modified.
                for _ in range(self.k):
                    proj_sum_gm = torch.sum(torch.mul(proj_g, proj_m))

                    scalar_g = proj_sum_gm / (torch.sum(torch.pow(proj_g, 2)).add_(self.xi))
                    scalar_m = proj_sum_gm / (torch.sum(torch.pow(proj_m, 2)).add_(self.xi))

                    proj_g.mul_(scalar_g)
                    proj_m.mul_(scalar_m)

                # Second moment of the projected residual; eps is also added to the buffer itself each step.
                grad_residual = proj_m - proj_g
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_var,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )

                # Non-rectified path: bias-correct the denominator here and take a plain adaptive step.
                if not group['rectify']:
                    de_nom.div_(bias_correction2_sq)
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                    continue

                # Rectified path: adaptive step once n_sma reaches the threshold, otherwise a momentum-only
                # (SGD-like) step, which is skipped when the helper returned a non-positive step size.
                if n_sma >= self.n_sma_threshold:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    p.add_(exp_avg, alpha=-step_size)

        return loss

Alice

Bases: BaseOptimizer

Adaptive low-dimensional subspace estimation.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.02
betas Betas

coefficients (beta1, beta2, beta3) for the running averages of the projected gradient, its element-wise square, and the projected second-moment matrix Q. beta3=0 for Alice-0 optimizer.

(0.9, 0.9, 0.999)
alpha float

scaling factor applied to the whole update (multiplies the learning rate).

0.3
alpha_c float

scaling factor of the compensation term added to the low-rank update.

0.4
update_interval int

number of steps between recomputations of the projection basis U.

200
rank int

dimension of the projection subspace (number of columns of U).

256
gamma float

limiter threshold: the norm of the compensation term may grow by at most this factor from one step to the next.

1.01
leading_basis int

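number of leading eigenvectors kept in U; the remaining rank - leading_basis columns come from the orthogonal complement.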

40
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

fix weight decay.

False
eps float

term added to the denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/racs.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
class Alice(BaseOptimizer):
    """Adaptive low-dimensional subspace estimation.

    Each gradient is viewed as a matrix ``G`` of shape ``(m, n)`` and projected onto a basis ``U`` with ``rank``
    columns. Adam-style moments ``m``/``v`` are kept in that subspace. A norm-limited compensation term, built from
    the residual ``G - U U^T G``, is added to the update. ``U`` is recomputed at the first step and every
    ``update_interval`` steps, and reused in between.

    NOTE(review): the state shapes assume the basis has exactly ``rank`` columns. For parameters whose first
    dimension ``m`` is smaller than ``rank - leading_basis``, the basis ends up narrower and the moment updates fail.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        betas (Betas): coefficients (beta1, beta2, beta3) for the running averages of the projected gradient, its
            element-wise square, and the projected second-moment matrix Q. beta3=0 for Alice-0 optimizer.
        alpha (float): scaling factor applied to the whole update.
        alpha_c (float): scaling factor of the compensation term.
        update_interval (int): number of steps between recomputations of the basis U.
        rank (int): dimension of the projection subspace (number of columns of U).
        gamma (float): limiter threshold; the compensation norm may grow by at most this factor per step.
        leading_basis (int): number of leading eigenvectors kept in U; the remaining ``rank - leading_basis``
            columns come from the orthogonal complement.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        eps (float): term added to the denominator to improve numerical stability.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 0.02,
        betas: Betas = (0.9, 0.9, 0.999),
        alpha: float = 0.3,
        alpha_c: float = 0.4,
        update_interval: int = 200,
        rank: int = 256,
        gamma: float = 1.01,
        leading_basis: int = 40,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_range(alpha_c, 'alpha_c', 0.0, 1.0)
        self.validate_positive(update_interval, 'update_interval')
        self.validate_positive(rank, 'rank')
        self.validate_positive(gamma, 'gamma')
        self.validate_positive(leading_basis, 'leading_basis')
        self.validate_non_negative(rank - leading_basis, 'rank - leading_basis')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'alpha_c': alpha_c,
            'update_interval': update_interval,
            'rank': rank,
            'gamma': gamma,
            'leading_basis': leading_basis,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Alice'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

    @staticmethod
    def subspace_iteration(
        a: torch.Tensor, mat: torch.Tensor, num_steps: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Perform subspace iteration.

        Refines the basis ``mat`` with ``num_steps`` rounds of ``u <- QR(a @ u)``. It then returns
        ``torch.linalg.eigh(u^T a u)``: ascending eigenvalues, and eigenvectors expressed in the coordinates of the
        refined basis ``u`` (not in the ambient space of ``a``).
        """
        u = mat
        for _ in range(num_steps):
            u, _ = torch.linalg.qr(a @ u)

        return torch.linalg.eigh(u.T @ a @ u)

    def switch(self, q: torch.Tensor, u_prev: torch.Tensor, rank: int, leading_basis: int) -> torch.Tensor:
        """Build a refreshed basis with up to ``rank`` columns from the ``(m, m)`` matrix ``q``.

        The first ``leading_basis`` columns are the leading Ritz vectors of ``q`` obtained from one subspace
        iteration started at ``u_prev``. The remaining columns come from the orthogonal complement of those vectors.
        The computation runs in float32 and the result is cast back to ``q.dtype``.
        """
        q32 = q.to(torch.float32)

        # One power-iteration step plus orthonormalisation: the refined (m, k) basis.
        basis, _ = torch.linalg.qr(q32 @ u_prev.to(torch.float32))
        # Rayleigh-Ritz on the refined basis (num_steps=0: no further iterations).
        vals, vecs = self.subspace_iteration(q32, basis, num_steps=0)

        leading_indices = torch.argsort(vals, descending=True)[:leading_basis]
        # ``vecs`` lives in the k-dimensional subspace coordinates; lift it back to the ambient m-dimensional space.
        # Without this lift, the identity below is (m, m) while ``u_t1 @ u_t1.T`` is (k, k), which fails whenever
        # m > rank.
        u_t1 = basis @ vecs[:, leading_indices]

        u_c, _ = torch.linalg.qr(torch.eye(q.shape[0], device=q.device) - u_t1 @ u_t1.T)
        u_t2 = u_c[:, :rank - leading_basis]  # fmt: skip

        return torch.cat([u_t1, u_t2], dim=1).to(q.dtype)

    @staticmethod
    def compensation(
        grad: torch.Tensor,
        u: torch.Tensor,
        p: torch.Tensor,
        phi: torch.Tensor,
        gamma: float,
        decay_rate: float,
        rank: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the norm-limited compensation term for the directions outside the subspace.

        Args:
            grad (torch.Tensor): gradient viewed as an ``(m, n)`` matrix.
            u (torch.Tensor): current basis with ``m`` rows.
            p (torch.Tensor): ``(n,)`` running per-column residual energy; updated in place.
            phi (torch.Tensor): norm of the previous compensation term (0 before the first step).
            gamma (float): limiter threshold.
            decay_rate (float): decay of the running estimate ``p``.
            rank (int): subspace rank.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the compensation term ``(m, n)`` and its norm.
        """
        m, _ = grad.shape

        sigma = u.T @ grad

        # Column-wise energy of G that is not captured by the subspace, as a running average (floored at 1e-8).
        p.mul_(decay_rate).add_(grad.pow(2).sum(dim=0) - sigma.pow(2).sum(dim=0), alpha=1.0 - decay_rate).clamp_min_(
            1e-8
        )

        if m >= rank:
            # Right-multiplying the residual by D = diag(p)^{-1/2} scales column j by 1 / sqrt(p[j]). Broadcasting the
            # (n,) vector over rows does exactly that; the previous element-wise product with a diagonal-only (m, n)
            # matrix discarded every off-diagonal entry.
            c_t = math.sqrt(m - rank) * (grad - u @ sigma) / p.sqrt()
        else:
            c_t = torch.zeros_like(grad)

        # Limiter: the compensation norm may be at most ``gamma`` times the previous norm ``phi``.
        scale = gamma / max(torch.norm(c_t) / phi, gamma) if phi.item() > 0 else torch.ones_like(phi)

        c_t.mul_(scale)

        return c_t, torch.norm(c_t)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform one Alice update; returns the closure's loss (or ``None``)."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2, beta3 = group['betas']
            rank, leading_basis = group['rank'], group['leading_basis']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                # ``maximize`` was stored but never applied; honour it like the other optimizers do.
                self.maximize_gradient(grad, maximize=self.maximize)

                # Apply weight decay while ``grad`` still has the parameter's shape. The coupled (L2) variant adds
                # ``p`` to ``grad``, which cannot broadcast against the 2-D view built below.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # Work on a 2-D view: vectors become one column, higher-rank tensors become (first dim, rest).
                if grad.ndim < 2:
                    grad = grad.reshape(len(grad), 1)
                elif grad.ndim > 2:
                    grad = grad.reshape(len(grad), -1)

                state = self.state[p]

                if len(state) == 0:
                    num_rows, num_cols = grad.shape

                    state['U'] = torch.zeros((num_rows, rank), dtype=p.dtype, device=p.device)
                    state['Q'] = torch.zeros((rank, rank), dtype=p.dtype, device=p.device)

                    state['m'] = torch.zeros((rank, num_cols), dtype=p.dtype, device=p.device)
                    state['v'] = torch.zeros((rank, num_cols), dtype=p.dtype, device=p.device)

                    state['p'] = torch.zeros((num_cols,), dtype=p.dtype, device=p.device)
                    state['phi'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

                q, u, m, v = state['Q'], state['U'], state['m'], state['v']

                if group['step'] == 1 or group['step'] % group['update_interval'] == 0:
                    q_t = beta3 * (u @ q @ u.T) + (1.0 - beta3) * (grad @ grad.T)
                    u = self.switch(q_t, u, rank, leading_basis)
                    # Persist the refreshed basis; until the next refresh every step must reuse it instead of the
                    # stale (initially all-zero) tensor left in the state.
                    state['U'] = u

                sigma = u.T @ grad

                q.mul_(beta3).add_(sigma @ sigma.T, alpha=1.0 - beta3)
                m.mul_(beta1).add_(sigma, alpha=1.0 - beta1)
                v.mul_(beta2).add_(sigma.pow(2), alpha=1.0 - beta2)

                c_t, phi = self.compensation(grad, u, state['p'], state['phi'], group['gamma'], beta1, rank)

                # ``eps`` keeps the division finite for coordinates whose second moment is still zero.
                update = u @ (m / v.sqrt().add_(group['eps']))
                update.add_(c_t, alpha=group['alpha_c'])

                p.add_(update.view_as(p), alpha=-group['lr'] * group['alpha'])

                state['phi'] = phi

        return loss

subspace_iteration(a, mat, num_steps=1) staticmethod

Perform subspace iteration.

Source code in pytorch_optimizer/optimizer/racs.py
218
219
220
221
222
223
224
225
226
227
@staticmethod
def subspace_iteration(
    a: torch.Tensor, mat: torch.Tensor, num_steps: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Refine a basis by subspace iteration and return the projected eigen-decomposition.

    The basis ``mat`` goes through ``num_steps`` rounds of ``QR(a @ basis)``. The result is
    ``torch.linalg.eigh(basis^T a basis)``: eigenvalues in ascending order, with eigenvectors expressed in the
    coordinates of the refined basis.
    """
    basis = mat
    remaining = num_steps
    while remaining > 0:
        basis = torch.linalg.qr(a @ basis)[0]
        remaining -= 1

    projected = basis.T @ a @ basis
    return torch.linalg.eigh(projected)

AliG

Bases: BaseOptimizer

Adaptive Learning Rates for Interpolation with Gradients.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
max_lr Optional[float]

Maximum learning rate.

None
projection_fn Callable

Projection function to enforce constraints.

None
momentum float

Momentum factor.

0.0
adjusted_momentum bool

If True, use PyTorch-like momentum instead of standard Nesterov momentum.

False
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/alig.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
class AliG(BaseOptimizer):
    """Adaptive Learning Rates for Interpolation with Gradients.

    The step size is the closure's loss divided by ``get_global_gradient_norm(...) + 1e-6``. It is capped by
    ``max_lr`` when that is set. An optional ``projection_fn`` is called after construction and after every group
    update.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        max_lr (Optional[float]): Maximum learning rate.
        projection_fn (Callable): Projection function to enforce constraints.
        momentum (float): Momentum factor.
        adjusted_momentum (bool): If True, use PyTorch-like momentum instead of standard Nesterov momentum.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        max_lr: Optional[float] = None,
        projection_fn: Optional[Callable] = None,
        momentum: float = 0.0,
        adjusted_momentum: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(max_lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)

        self.projection_fn = projection_fn
        self.maximize = maximize

        super().__init__(params, {'max_lr': max_lr, 'adjusted_momentum': adjusted_momentum, 'momentum': momentum})

        # Project once up front so optimization starts from a point satisfying the constraints.
        if self.projection_fn is not None:
            self.projection_fn()

    def __str__(self) -> str:
        return 'AliG'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and allocate momentum buffers when momentum is enabled."""
        group.setdefault('step', 0)

        momentum: float = kwargs.get('momentum', 0.9)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if momentum > 0.0 and not param_state:
                param_state['momentum_buffer'] = torch.zeros_like(param)

    @torch.no_grad()
    def compute_step_size(self, loss: float) -> float:
        r"""Compute step_size."""
        grad_norm = get_global_gradient_norm(self.param_groups).add_(1e-6)
        return loss / grad_norm.item()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform one AliG update. A closure returning the loss is mandatory."""
        if closure is None:
            raise NoClosureError('AliG', '(e.g. `optimizer.step(lambda: float(loss))`).')

        loss = closure()

        un_clipped_step_size: float = self.compute_step_size(loss)

        for group in self.param_groups:
            momentum = group['momentum']

            self.init_group(group, momentum=momentum)
            group['step'] += 1

            max_lr = group['max_lr']
            step_size = un_clipped_step_size if max_lr is None else min(un_clipped_step_size, max_lr)
            group['step_size'] = step_size

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                buffer = self.state[param].get('momentum_buffer', None)
                param, grad, buffer = self.view_as_real(param, grad, buffer)

                param.add_(grad, alpha=-step_size)

                if buffer is None:
                    continue

                if group['adjusted_momentum']:
                    buffer.mul_(momentum).sub_(grad)
                    param.add_(buffer, alpha=step_size * momentum)
                else:
                    buffer.mul_(momentum).add_(grad, alpha=-step_size)
                    param.add_(buffer, alpha=momentum)

            if self.projection_fn is not None:
                self.projection_fn()

        return loss

compute_step_size(loss)

Compute step_size.

Source code in pytorch_optimizer/optimizer/alig.py
79
80
81
82
83
84
85
@torch.no_grad()
def compute_step_size(self, loss: float) -> float:
    r"""Return ``loss`` divided by the global gradient norm of ``self.param_groups`` plus 1e-6."""
    denominator: float = get_global_gradient_norm(self.param_groups).add_(1e-6).item()
    return loss / denominator

Amos

Bases: BaseOptimizer

An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
beta float

A float slightly less than 1. Recommended to set 1 - beta approximately the same magnitude as the learning rate, similar to beta2 in Adam.

0.999
momentum float

Exponential decay rate for optional moving average of updates.

0.0
extra_l2 float

Additional L2 regularization.

0.0
c_coef float

Coefficient for decay_factor_c.

0.25
d_coef float

Coefficient for decay_factor_d.

0.25
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
eps float

Term added to the denominator to improve numerical stability.

1e-18
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/amos.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
class Amos(BaseOptimizer):
    """An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale.

    Amos keeps one scalar second-moment estimate (the running mean of ``grad ** 2`` over all elements of a tensor)
    and one scalar ``decay`` accumulator per parameter tensor. The per-tensor step size is
    ``lr * get_scale(p)``, i.e. it depends on the rank and shape of the parameter.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        beta (float): A float slightly less than 1. It is recommended to set `1 - beta` to approximately the same
            magnitude as the learning rate, similar to beta2 in Adam.
        momentum (float): Exponential decay rate for optional moving average of updates. ``0.0`` disables it and
            no ``exp_avg`` buffer is allocated.
        extra_l2 (float): Additional L2 regularization.
        c_coef (float): Coefficient for decay_factor_c. Shared by all parameter groups.
        d_coef (float): Coefficient for decay_factor_d. Shared by all parameter groups.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        beta: float = 0.999,
        momentum: float = 0.0,
        extra_l2: float = 0.0,
        c_coef: float = 0.25,
        d_coef: float = 0.25,
        foreach: Optional[bool] = None,
        eps: float = 1e-18,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(extra_l2, 'extra_l2')
        self.validate_non_negative(eps, 'eps')

        # c_coef / d_coef live on the optimizer instead of in `defaults`, so they cannot differ per param group.
        self.c_coef = c_coef
        self.d_coef = d_coef
        self.foreach = foreach
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'beta': beta,
            'momentum': momentum,
            'extra_l2': extra_l2,
            'foreach': foreach,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Amos'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialize the step counter and lazily allocate per-parameter state.

        Raises:
            NoSparseGradientError: if any parameter in the group has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # One-element buffers: Amos tracks a single second-moment value and decay value per tensor.
                state['exp_avg_sq'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                state['decay'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                if group['momentum'] > 0.0:
                    state['exp_avg'] = torch.zeros_like(p)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        r"""Return False when the group explicitly disables foreach, else defer to ``self.can_use_foreach``."""
        if group.get('foreach') is False:
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    @staticmethod
    def get_scale(p: torch.Tensor) -> float:
        r"""Get expected scale for model weights.

        The scale depends only on the tensor's rank:

        * rank 0 or 1: ``0.5``
        * rank 2: ``sqrt(2 / p.size(1))``
        * rank >= 3: ``sqrt(1 / p.size(1))``

        Args:
            p (torch.Tensor): parameter tensor.

        Returns:
            float: the expected scale of ``p``.
        """
        # 0-dim (scalar) parameters have no dimension 1, so `p.size(1)` would raise IndexError; treat them like
        # 1-dim parameters.
        if p.dim() <= 1:
            return 0.5
        if p.dim() == 2:
            return math.sqrt(2 / p.size(1))
        return math.sqrt(1 / p.size(1))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        exp_avg_sqs: List[torch.Tensor],
        decays: List[torch.Tensor],
    ) -> None:
        r"""Apply one Amos step to ``params`` using multi-tensor (``torch._foreach_*``) kernels.

        Computes the same quantities as ``_step_per_param``, batched over all tensors of the group.
        """
        lr_sq: float = math.sqrt(group['lr'])
        lr_p2: float = math.pow(group['lr'], 2)

        beta: float = group['beta']
        bias_correction: float = self.debias(beta, group['step'])

        if self.maximize:
            torch._foreach_neg_(grads)

        g2 = [grad.pow(2).mean() for grad in grads]
        init_lrs: List[float] = [group['lr'] * self.get_scale(p) for p in params]

        # exp_avg_sq = beta * exp_avg_sq + (1 - beta) * g2
        torch._foreach_mul_(exp_avg_sqs, beta)
        torch._foreach_add_(exp_avg_sqs, g2, alpha=1.0 - beta)

        # r_v_hat = bias_correction / (exp_avg_sq + eps)
        r_v_hat = torch._foreach_add(exp_avg_sqs, group['eps'])
        torch._foreach_reciprocal_(r_v_hat)
        torch._foreach_mul_(r_v_hat, bias_correction)

        # df_c: decay_factor_c = rsqrt(1 + c_coef * sqrt(lr) * decay)
        # (foreach_rsqrt_ is presumably the in-place batched counterpart of torch.rsqrt used per-param)
        df_c = torch._foreach_mul(decays, self.c_coef * lr_sq)
        torch._foreach_add_(df_c, 1.0)
        foreach_rsqrt_(df_c)

        # df_d = 1 + d_coef * sqrt(init_lr) * decay; the update is divided by it below rather than multiplied
        # by its reciprocal.
        d_step_sizes = [self.d_coef * math.sqrt(step_size) for step_size in init_lrs]
        df_d = torch._foreach_mul(decays, d_step_sizes)
        torch._foreach_add_(df_d, 1.0)

        # df_c now becomes gamma = decay_factor_c * lr^2 * r_v_hat * g2 (buffer reused in place).
        torch._foreach_mul_(df_c, r_v_hat)
        torch._foreach_mul_(df_c, lr_p2)
        torch._foreach_mul_(df_c, g2)

        # update = p * (gamma - extra_l2) / 2
        updates = torch._foreach_div(params, 2.0)
        torch._foreach_mul_(updates, torch._foreach_sub(df_c, group['extra_l2']))

        # update += init_lr * sqrt(r_v_hat) * grad (r_v_hat buffer reused)
        torch._foreach_sqrt_(r_v_hat)
        torch._foreach_mul_(r_v_hat, init_lrs)
        torch._foreach_add_(updates, torch._foreach_mul(grads, r_v_hat))

        torch._foreach_div_(updates, df_d)

        # decay = decay * (1 + gamma) + gamma
        torch._foreach_mul_(decays, torch._foreach_add(df_c, 1.0))
        torch._foreach_add_(decays, df_c)

        if group['momentum'] > 0.0:
            # lerp with weight (1 - momentum) == momentum * exp_avg + (1 - momentum) * update
            torch._foreach_lerp_(exp_avgs, updates, weight=1.0 - group['momentum'])
            torch._foreach_copy_(updates, exp_avgs)

        torch._foreach_sub_(params, updates)

    def _step_per_param(self, group: ParamGroup) -> None:
        r"""Apply one Amos step to each parameter of ``group`` individually."""
        momentum, beta = group['momentum'], group['beta']

        lr_sq: float = math.sqrt(group['lr'])
        lr_p2: float = math.pow(group['lr'], 2)
        bias_correction: float = self.debias(beta, group['step'])

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            # Mean of the squared gradient over the whole tensor: a single scalar per parameter.
            g2 = grad.pow(2).mean()
            init_lr: float = group['lr'] * self.get_scale(p)

            exp_avg_sq = state['exp_avg_sq']
            exp_avg_sq.mul_(beta).add_(g2, alpha=1.0 - beta)

            r_v_hat = bias_correction / (exp_avg_sq + group['eps'])

            decay = state['decay']
            decay_factor_c = torch.rsqrt(1.0 + self.c_coef * lr_sq * decay)
            decay_factor_d = torch.reciprocal(1.0 + self.d_coef * math.sqrt(init_lr) * decay)

            gamma = decay_factor_c * lr_p2 * r_v_hat * g2

            # update = decay_factor_d * (p * (gamma - extra_l2) / 2 + init_lr * sqrt(r_v_hat) * grad)
            update = p.clone()
            update.mul_((gamma - group['extra_l2']) / 2.0)
            update.add_(r_v_hat.sqrt() * grad, alpha=init_lr)
            update.mul_(decay_factor_d)

            decay.mul_(1.0 + gamma).add_(gamma)

            if momentum > 0.0:
                exp_avg = state['exp_avg']
                exp_avg.mul_(momentum).add_(update, alpha=1.0 - momentum)

                update.copy_(exp_avg)

            p.add_(-update)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional closure that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['exp_avg', 'exp_avg_sq', 'decay']
                )
                if params:
                    self._step_foreach(
                        group,
                        params,
                        grads,
                        state_dict['exp_avg'],
                        state_dict['exp_avg_sq'],
                        state_dict['decay'],
                    )
            else:
                self._step_per_param(group)

        return loss

get_scale(p) staticmethod

Get expected scale for model weights.

Source code in pytorch_optimizer/optimizer/amos.py
 96
 97
 98
 99
100
101
102
103
@staticmethod
def get_scale(p: torch.Tensor) -> float:
    r"""Get expected scale for model weights.

    The scale depends only on the tensor's rank:

    * rank 0 or 1: ``0.5``
    * rank 2: ``sqrt(2 / p.size(1))``
    * rank >= 3: ``sqrt(1 / p.size(1))``

    Args:
        p (torch.Tensor): parameter tensor.

    Returns:
        float: the expected scale of ``p``.
    """
    # 0-dim (scalar) parameters have no dimension 1, so `p.size(1)` would raise IndexError; treat them like
    # 1-dim parameters.
    if p.dim() <= 1:
        return 0.5
    if p.dim() == 2:
        return math.sqrt(2 / p.size(1))
    return math.sqrt(1 / p.size(1))

Ano

Bases: BaseOptimizer

Ano optimizer with adaptive momentum and sign-based updates.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.0001
betas Betas

Coefficients used for computing running averages of gradient and the squared gradient.

(0.92, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
logarithmic_schedule bool

Enable adaptive beta1 scheduling based on step count.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ano.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
class Ano(BaseOptimizer):
    r"""Ano optimizer with adaptive momentum and sign-based updates.

    The update takes its direction from the sign of the first moment and its magnitude from the current gradient:
    ``p -= lr * |grad| * sign(exp_avg) / (sqrt(exp_avg_sq / bias_correction2) + eps)``, where ``exp_avg_sq`` is
    updated with ``beta2 * v + (1 - beta2) * sign(grad ** 2 - v) * grad ** 2``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of gradient and the squared gradient.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        logarithmic_schedule (bool): Enable adaptive beta1 scheduling based on step count
            (``beta1 = 1 - 1 / log(max(2, step))``); the configured beta1 is then ignored.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-4,
        betas: Betas = (0.92, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        logarithmic_schedule: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.logarithmic_schedule = logarithmic_schedule
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Ano'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialize the step counter and lazily allocate ``exp_avg`` / ``exp_avg_sq`` for each parameter.

        Raises:
            NoSparseGradientError: if any parameter in the group has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional closure that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            if self.logarithmic_schedule:
                # NOTE(review): 1 / log(2) > 1, so beta1 is negative (about -0.44) at steps 1 and 2 and small for
                # the next few steps — confirm this matches the intended schedule.
                max_t = max(2, group['step'])
                beta1 = 1.0 - 1.0 / math.log(max_t)

            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                square_grad = grad.square()

                # The sign must compare grad ** 2 with the *previous* exp_avg_sq. Passing it as an argument to the
                # chained `exp_avg_sq.mul_(beta2).addcmul_(...)` call would evaluate it after `mul_` has already
                # scaled exp_avg_sq in place, i.e. it would compare against beta2 * exp_avg_sq.
                sign_term = torch.sign(square_grad - exp_avg_sq)
                exp_avg_sq.mul_(beta2).addcmul_(sign_term, square_grad, value=1.0 - beta2)

                # square_grad is no longer needed, so its buffer is reused for the denominator.
                de_nom = square_grad.copy_(exp_avg_sq).div_(bias_correction2).sqrt_().add_(group['eps'])

                p.addcdiv_(grad.abs().mul_(exp_avg.sign()), de_nom, value=-group['lr'])

        return loss

APOLLO

Bases: BaseOptimizer

SGD-like Memory, AdamW-level Performance.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
betas Betas

Coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
correct_bias bool

Whether to correct bias in Adam.

True
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/apollo.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
class APOLLO(BaseOptimizer):
    """SGD-like Memory, AdamW-level Performance.

    Without a ``rank`` entry in the parameter group, the update direction is ``exp_avg / (sqrt(exp_avg_sq) + eps)``.
    When the group defines ``rank``, gradients of parameters with more than one dimension are first projected by a
    ``GaLoreProjector`` (built from a random matrix). The moments are tracked for the projected gradient, and the
    update is the projected gradient rescaled by the ratio of the normalized direction's norm to the projected
    gradient's norm, then projected back. Such groups must also define ``update_proj_gap``, ``scale`` and
    ``projection_type``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and the squared gradient.
        scale_type (SCALE_TYPE): Granularity of the low-rank scaling factor. ``'channel'`` computes one factor per
            row or per column of the projected gradient; any other value computes a single factor per tensor.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        correct_bias (bool): Whether to correct bias in Adam.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        betas: Betas = (0.9, 0.999),
        scale_type: SCALE_TYPE = 'tensor',
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        correct_bias: bool = True,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'scale_type': scale_type,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'correct_bias': correct_bias,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'APOLLO'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialize the step counter and lazily allocate parameter-shaped moment buffers.

        Raises:
            NoSparseGradientError: if any parameter in the group has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional closure that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            step_size: float = group['lr']
            if group['correct_bias']:
                bias_correction1: float = self.debias(beta1, group['step'])
                bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
                step_size *= bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                if not group['weight_decouple']:
                    # Non-decoupled decay is an L2 penalty that only takes effect through `grad` (assumes
                    # apply_weight_decay folds it into `grad` in place). It has to run before `grad` feeds the
                    # moments and before it is projected to a different shape. Applied after the parameter
                    # update, as for decoupled decay below, it would have no effect.
                    self.apply_weight_decay(
                        p,
                        grad,
                        lr=step_size,
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                use_low_rank: bool = 'rank' in group and p.dim() > 1
                if use_low_rank:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )

                    grad = state['projector'].project(grad, group['step'], from_random_matrix=True)

                    if exp_avg.shape != grad.shape:
                        # init_group allocates the moments with the parameter's shape. If the projection changed
                        # the gradient's shape, re-allocate them once at the projected shape. Otherwise the
                        # in-place updates below would fail, or silently broadcast.
                        exp_avg = state['exp_avg'] = torch.zeros_like(grad)
                        exp_avg_sq = state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                norm_grad = exp_avg / de_nom
                if use_low_rank:
                    if group['scale_type'] == 'channel':
                        norm_dim: int = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
                        scaling_factor = torch.norm(norm_grad, dim=norm_dim) / (torch.norm(grad, dim=norm_dim) + 1e-8)
                        if norm_dim == 1:
                            # Per-row factors must broadcast along columns.
                            scaling_factor = scaling_factor.unsqueeze(1)
                    else:
                        scaling_factor = torch.norm(norm_grad) / (torch.norm(grad) + 1e-8)

                    scaling_grad = grad * scaling_factor

                    # Norm-growth limiter: the scaled gradient's norm may grow by at most 1% per step relative to
                    # the previous step's (stored) norm.
                    scaling_grad_norm = torch.norm(scaling_grad)
                    if 'scaling_grad' in state:
                        limiter = (
                            max(
                                scaling_grad_norm / (state['scaling_grad'] + 1e-8),
                                1.01,
                            )
                            / 1.01
                        )

                        scaling_grad.div_(limiter)
                        scaling_grad_norm.div_(limiter)

                    state['scaling_grad'] = scaling_grad_norm

                    norm_grad = scaling_grad * np.sqrt(group['scale'])
                    norm_grad = state['projector'].project_back(norm_grad)

                p.add_(norm_grad, alpha=-step_size)

                if group['weight_decouple']:
                    self.apply_weight_decay(
                        p,
                        grad,
                        lr=step_size,
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

        return loss

ApolloDQN

Bases: BaseOptimizer

An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
init_lr Optional[float]

Initial learning rate of the linear warmup. If None, lr / 1000 is used.

1e-05
beta float

Coefficient used for computing running averages of gradient.

0.9
rebound str

Rectified bound for diagonal Hessian. Options: 'constant', 'belief'.

'constant'
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decay_type str

Type of weight decay. Options: 'l2', 'decoupled', 'stable'.

'l2'
warmup_steps int

Number of warmup steps.

500
eps float

Term added to the denominator to improve numerical stability.

0.0001
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/apollo.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class ApolloDQN(BaseOptimizer):
    """An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        init_lr (Optional[float]): Initial learning rate of the linear warmup. If None, ``lr / 1000`` is used.
        beta (float): Coefficient used for computing running averages of gradient. Must be in [0, 1).
        rebound (str): Rectified bound for diagonal Hessian. Options: 'constant', 'belief'.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decay_type (str): Type of weight decay. Options: 'l2', 'decoupled', 'stable'.
        warmup_steps (int): Number of warmup steps.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        init_lr: Optional[float] = 1e-5,
        beta: float = 0.9,
        rebound: str = 'constant',
        weight_decay: float = 0.0,
        weight_decay_type: str = 'l2',
        warmup_steps: int = 500,
        eps: float = 1e-4,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        # beta == 1 would make the bias correction 0 and divide by zero in `step`, so it is excluded.
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)')
        self.validate_options(rebound, 'rebound', ['constant', 'belief'])
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_options(weight_decay_type, 'weight_decay_type', ['l2', 'decoupled', 'stable'])
        self.validate_non_negative(eps, 'eps')

        self.lr = lr
        self.warmup_steps = warmup_steps
        self.init_lr: float = init_lr if init_lr is not None else lr / 1000.0
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'init_lr': self.init_lr,
            'beta': beta,
            'rebound': rebound,
            'weight_decay': weight_decay,
            'weight_decay_type': weight_decay_type,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ApolloDQN'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialize the step counter and lazily allocate per-parameter state.

        Raises:
            NoSparseGradientError: if any parameter in the group has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg_grad'] = torch.zeros_like(p)
                state['approx_hessian'] = torch.zeros_like(p)
                state['update'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional closure that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            # Linear warmup from `init_lr` for `warmup_steps` steps, then `group['lr']`.
            # NOTE(review): the warmup target is the constructor lr (`self.lr`), not the group's own lr, so a
            # group configured with a different lr jumps when warmup ends — confirm this is intended.
            current_lr: float = (
                group['lr']
                if group['step'] >= self.warmup_steps
                else (self.lr - group['init_lr']) * group['step'] / self.warmup_steps + group['init_lr']
            )

            weight_decay: float = group['weight_decay']

            bias_correction: float = self.debias(group['beta'], group['step'])
            alpha: float = (1.0 - group['beta']) / bias_correction

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg_grad, b, d_p = state['exp_avg_grad'], state['approx_hessian'], state['update']

                p, grad, exp_avg_grad, b, d_p = self.view_as_real(p, grad, exp_avg_grad, b, d_p)

                if weight_decay > 0.0 and group['weight_decay_type'] == 'l2':
                    grad.add_(p, alpha=weight_decay)

                delta_grad = grad - exp_avg_grad

                # eps is derived per parameter from the group value. It must not be carried over between
                # parameters, otherwise the constant-rebound division compounds (eps grows 100x per parameter).
                eps: float = group['eps']
                if group['rebound'] == 'belief':
                    rebound = delta_grad.norm(p=np.inf)
                else:
                    rebound = 1e-2
                    eps /= rebound

                exp_avg_grad.add_(delta_grad, alpha=alpha)

                de_nom = d_p.norm(p=4).add_(eps)
                d_p.div_(de_nom)

                # Update of the diagonal Hessian approximation `b` from the previous (normalized) step `d_p`.
                v_sq = d_p.mul(d_p)
                delta = delta_grad.div_(de_nom).mul_(d_p).sum().mul(-alpha) - b.mul(v_sq).sum()

                b.addcmul_(v_sq, delta)

                de_nom = b.abs().clamp_(min=rebound)
                if group['rebound'] == 'belief':
                    de_nom.add_(eps / alpha)

                d_p.copy_(exp_avg_grad.div(de_nom))

                if weight_decay > 0.0 and group['weight_decay_type'] != 'l2':
                    # Per-parameter rate: the 'stable' normalization must not leak into later parameters.
                    decay_rate: float = weight_decay
                    if group['weight_decay_type'] == 'stable':
                        decay_rate /= de_nom.mean().item()

                    d_p.add_(p, alpha=decay_rate)

                p.add_(d_p, alpha=-current_lr)

        return loss

ASGD

Bases: BaseOptimizer

Adaptive SGD with estimation of the local smoothness (curvature).

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.01
amplifier float

amplifier for the learning-rate growth candidate lr * sqrt(1 + amplifier * theta).

0.02
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

fix weight decay.

False
theta float

initial ratio between consecutive learning rates; updated to new_lr / old_lr every step.

1.0
dampening float

divisor applied to the gradient-norm change in the curvature-based learning-rate bound.

1.0
eps float

term added to the learning rate whenever the curvature-based bound is used.

1e-05
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
class ASGD(BaseOptimizer):
    """Adaptive SGD with estimation of the local smoothness (curvature).

    Every step, each parameter group's learning rate is re-estimated from how the group-wide parameter norm and
    gradient norm changed since the previous step, and then a plain SGD step is taken with it. Note that the
    *differences of the norms* are used, not the norms of the differences.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): initial learning rate; `group['lr']` is overwritten with the adapted value every step.
        amplifier (float): amplifier for the learning-rate growth candidate `lr * sqrt(1 + amplifier * theta)`.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        theta (float): initial ratio between consecutive learning rates; updated to `new_lr / old_lr` every step.
        dampening (float): divisor applied to the gradient-norm change in the curvature-based learning-rate bound.
        eps (float): term added to the learning rate whenever the curvature-based bound is used.
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        amplifier: float = 0.02,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        theta: float = 1.0,
        dampening: float = 1.0,
        eps: float = 1e-5,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(amplifier, 'amplifier')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'amplifier': amplifier,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'theta': theta,
            'dampening': dampening,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ASGD'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        # ASGD keeps its running quantities (norms, theta, lr) in the param group itself; there is no
        # per-parameter state to initialize.
        pass

    @staticmethod
    def get_norms_by_group(group: ParamGroup, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get parameter & gradient norm by group.

        Returns the L2 norm of all parameters with a gradient and the L2 norm of their gradients, i.e.
        `sqrt(sum(||t|| ** 2))`, as one-element float32 tensors on `device`. Parameters without a gradient are
        skipped entirely.
        """
        p_norm = torch.zeros(1, dtype=torch.float32, device=device)
        g_norm = torch.zeros(1, dtype=torch.float32, device=device)

        for p in group['params']:
            if p.grad is None:
                continue

            p_norm.add_(p.norm().pow(2))
            g_norm.add_(p.grad.norm().pow(2))

        p_norm.sqrt_()
        g_norm.sqrt_()

        return p_norm, g_norm

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional closure that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            device = group['params'][0].device

            # First step: seed the "previous" norms with the current ones, so both differences are exactly 0.
            if 'prev_param_norm' not in group and 'prev_grad_norm' not in group:
                group['prev_param_norm'], group['prev_grad_norm'] = self.get_norms_by_group(group, device)

            group['curr_param_norm'], group['curr_grad_norm'] = self.get_norms_by_group(group, device)

            param_diff_norm: float = (group['curr_param_norm'] - group['prev_param_norm']).item()
            grad_diff_norm: float = (group['curr_grad_norm'] - group['prev_grad_norm']).item()

            # Growth candidate, capped by the curvature estimate only when both norms increased.
            new_lr: float = group['lr'] * math.sqrt(1 + group['amplifier'] * group['theta'])
            if param_diff_norm > 0 and grad_diff_norm > 0:
                new_lr = min(new_lr, param_diff_norm / (group['dampening'] * grad_diff_norm)) + group['eps']

            group['theta'] = new_lr / group['lr']
            group['lr'] = new_lr

            group['prev_param_norm'].copy_(group['curr_param_norm'])
            group['prev_grad_norm'].copy_(group['curr_grad_norm'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                # Honor `maximize` (it was stored in __init__ but never applied). The norms above are computed
                # from gradient magnitudes, so the sign flip does not affect them.
                self.maximize_gradient(grad, maximize=self.maximize)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.add_(grad, alpha=-new_lr)

        return loss

get_norms_by_group(group, device) staticmethod

Get parameter & gradient norm by group.

Source code in pytorch_optimizer/optimizer/sgd.py
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
@staticmethod
def get_norms_by_group(group: ParamGroup, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the L2 norms of the parameters and of their gradients within one param group.

    Parameters whose ``grad`` is ``None`` contribute to neither norm. Both results are 1-element
    float32 tensors allocated on ``device``.
    """
    param_sq_sum = torch.zeros(1, dtype=torch.float32, device=device)
    grad_sq_sum = torch.zeros(1, dtype=torch.float32, device=device)

    params_with_grad = (param for param in group['params'] if param.grad is not None)
    for param in params_with_grad:
        param_sq_sum.add_(param.norm().pow(2))
        grad_sq_sum.add_(param.grad.norm().pow(2))

    # sqrt_ works in place and returns the tensor itself.
    return param_sq_sum.sqrt_(), grad_sq_sum.sqrt_()

AvaGrad

Bases: BaseOptimizer

Domain-independent Dominance of Adaptive Methods.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.1
betas Betas

Coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
eps float

Term added to the denominator to improve numerical stability.

0.1
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/avagrad.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class AvaGrad(BaseOptimizer):
    """Domain-independent Dominance of Adaptive Methods.

    The per-element step ``exp_avg / (sqrt(exp_avg_sq / bias_correction2) + eps)`` is rescaled by a group-wide factor
    ``gamma = 1 / ||eta / sqrt(d)||_2``, where ``eta`` is the element-wise reciprocal of that denominator and ``d``
    is the number of elements in the group. ``gamma`` is computed at the end of each step for use by the next one, so
    the adaptive update is not applied on the very first step.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and the squared gradient.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-1,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-1,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'gamma': None,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AvaGrad'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
            prev_bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'] - 1))

            # gamma is only available after a full pass over the group, so it is used from step 2 onwards.
            step_size: float = group['lr']
            if group['step'] > 1:
                step_size = self.apply_adam_debias(
                    adam_debias=group.get('adam_debias', False),
                    step_size=group['gamma'] * group['lr'],
                    bias_correction1=bias_correction1,
                )

            squared_norm: float = 0.0
            num_params: float = 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                if group['step'] > 1:
                    # The update divides by the second moment from *before* this step's EMA update.
                    de_nom = exp_avg_sq.sqrt().div_(prev_bias_correction2_sq).add_(group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # Use the *updated* second moment: this is exactly the denominator the next step divides by,
                # so gamma normalizes the same per-element rates it will be applied to.
                param_wise_lr = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
                # norm(-2) ** -2 == sum(param_wise_lr ** -2), i.e. the squared L2 norm of 1 / denominator.
                squared_norm += param_wise_lr.norm(-2) ** -2
                num_params += param_wise_lr.numel()

            group['gamma'] = 0.0 if num_params == 0.0 else 1.0 / math.sqrt(squared_norm / num_params)

        return loss

BCOS

Bases: BaseOptimizer

Stochastic Approximation with Block Coordinate Optimal Stepsizes.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
beta float

smoothing factor in computing the momentum and EMA estimators.

0.9
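beta2 Optional[float]

Optional smoothing factor for the second-moment estimator. If None, `beta` is used in modes 'g' and 'm', and 1 - (1 - beta)^2 is used in the simple BCOS-c variant.

None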
mode Mode

algorithmic mode of BCOS, must be one of the three choices. 'g': use gradient as search direction and EMA estimator for its 2nd moment (equivalent to RMSprop). 'm': use momentum as search direction and EMA estimator for its 2nd moment (using same beta). 'c': use momentum as search direction and conditional estimator for its 2nd moment.

'c'
simple_cond bool

Whether to use the simpler conditional estimator in the BCOS-c variant.

False
weight_decay float

weight decay regularization strength.

0.1
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/bcos.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
class BCOS(BaseOptimizer):
    """Stochastic Approximation with Block Coordinate Optimal Stepsizes.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        beta (float): smoothing factor in computing the momentum and EMA estimators.
        beta2 (Optional[float]): optional smoothing factor for the second-moment estimator. If None, `beta` is used
            in modes 'g' and 'm', and 1 - (1 - beta)^2 is used in the simple BCOS-c variant.
        mode (Mode): algorithmic mode of BCOS, must be one of the three choices.
            'g': use gradient as search direction and EMA estimator for its 2nd moment (equivalent to RMSprop).
            'm': use momentum as search direction and EMA estimator for its 2nd moment (using same beta).
            'c': use momentum as search direction and conditional estimator for its 2nd moment.
        simple_cond (bool): whether to use the simpler conditional estimator in the BCOS-c variant.
        weight_decay (float): weight decay regularization strength.
        weight_decouple (bool): whether to use decoupled weight decay as in AdamW.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        beta: float = 0.9,
        beta2: Optional[float] = None,
        mode: Mode = 'c',
        simple_cond: bool = False,
        weight_decay: float = 0.1,
        weight_decouple: bool = True,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0)
        self.validate_options(mode, 'mode', ['g', 'm', 'c'])
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.mode = mode
        self.simple_cond = simple_cond
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'beta': beta,
            'beta2': beta2,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
            **kwargs,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'BCOS'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if self.mode in ('m', 'c') and 'm' not in state:
                # `step` calls `maximize_gradient` on `grad` only after this method runs. That helper is assumed
                # to flip the gradient in place (its return value is discarded). Apply the same sign here so the
                # momentum is not seeded with the opposite sign when maximizing.
                state['m'] = grad.neg() if self.maximize else grad.clone()

            if self.mode in ('g', 'm') and 'v' not in state:
                state['v'] = grad.square()

    def compute_v(self, grad: torch.Tensor, m: torch.Tensor, beta: float, beta2: Optional[float]) -> torch.Tensor:
        r"""Conditional 2nd-moment estimate for BCOS-c from the previous momentum `m` and the current gradient.

        Returns a new tensor; `m` and `grad` are not modified.
        """
        g2 = grad.square()

        if self.simple_cond:
            beta_v: float = 1.0 - (1.0 - beta) ** 2 if beta2 is None else beta2
            return beta_v * m.square() + (1.0 - beta_v) * g2

        return (
            (3.0 * beta ** 2 - 2.0 * beta ** 3) * m.square()
            + (1.0 - beta) ** 2 * g2
            + 2.0 * beta * (1.0 - beta) ** 2 * m * grad
        )  # fmt: skip

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta, beta2 = group['beta'], group['beta2']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                if self.mode == 'c':
                    # The conditional estimator needs m_{t-1}. Evaluate it before the in-place momentum update
                    # below. Previously a reference to state['m'] was kept, which already held m_t by the time
                    # compute_v ran.
                    v: torch.Tensor = self.compute_v(grad, state['m'], beta, beta2)

                if self.mode in ('m', 'c'):
                    m = state['m']
                    m.mul_(beta).add_(grad, alpha=1.0 - beta)
                    d = m
                else:
                    d = grad

                if self.mode in ('g', 'm'):
                    beta_v: float = beta if beta2 is None else beta2

                    v = state['v']
                    v.mul_(beta_v).add_(d.square(), alpha=1.0 - beta_v)

                p.addcdiv_(d, v.sqrt().add_(group['eps']), value=-group['lr'])

        return loss

BSAM

Bases: BaseOptimizer

SAM as an Optimal Relaxation of Bayes.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
num_data int

number of training data.

required
lr float

learning rate.

0.5
betas Betas

coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

weight decay (L2 penalty).

0.0001
rho float

size of the neighborhood for computing the max loss.

0.05
adaptive bool

element-wise Adaptive SAM.

False
damping float

damping to stabilize the method.

0.1
kwargs Dict

parameters for optimizer.

{}
Example
model = YourModel()
optimizer = BSAM(model.parameters(), ...)

def closure():
    optimizer.zero_grad()
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

for input, output in data:
    loss = loss_function(output, model(input))
    loss.backward()

    optimizer.step(closure)
    optimizer.zero_grad()
Source code in pytorch_optimizer/optimizer/sam.py
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
class BSAM(BaseOptimizer):
    """SAM as an Optimal Relaxation of Bayes.

    Each call to ``step`` runs three phases around two closure evaluations:

    1. ``first_step``: remember the current weights and perturb them with Gaussian noise.
    2. ``second_step``: keep the gradient measured at the noisy weights, restore the weights, then move to the
       SAM ascent point ``rho * noisy_gradient / s``.
    3. ``third_step``: restore the weights and take the descent step with the gradient measured at the ascent
       point. The precision estimate ``s`` is updated from the noisy gradient.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        num_data (int): number of training data.
        lr (float): learning rate.
        betas (Betas): coefficients for the running averages of the momentum (beta1) and of ``s`` (beta2).
        weight_decay (float): weight decay (L2 penalty).
        rho (float): size of the neighborhood for computing the max loss.
        adaptive (bool): element-wise Adaptive SAM.
        damping (float): damping to stabilize the method.
        kwargs (Dict): parameters for optimizer.

    Example:
        ```python
        model = YourModel()
        optimizer = BSAM(model.parameters(), ...)

        def closure():
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward()
            return loss

        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()

            optimizer.step(closure)
            optimizer.zero_grad()
        ```

    """

    def __init__(
        self,
        params: ParamsT,
        num_data: int,
        lr: float = 5e-1,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 1e-4,
        rho: float = 0.05,
        adaptive: bool = False,
        damping: float = 0.1,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(rho, 'rho')
        # num_data appears in a denominator of the noise scale, so zero is invalid.
        self.validate_positive(num_data, 'num_data')
        self.validate_non_negative(damping, 'damping')

        self.num_data = num_data
        self.damping = damping

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'rho': rho,
            'adaptive': adaptive,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'bSAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            state = self.state[p]

            if 's' not in state:
                state['s'] = torch.ones_like(p)
                state['noisy_gradient'] = torch.zeros_like(p.grad)
                state['momentum'] = torch.zeros_like(p)

    @torch.no_grad()
    def first_step(self):
        r"""Save the current weights, then perturb them with Gaussian noise scaled by ``1 / (num_data * s)``."""
        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # Unperturbed weights; both later phases restore from this copy.
                state['old_p'] = p.clone()

                # NOTE(review): torch.normal's second argument is a standard deviation. If 1 / (N * s) is meant
                # as the posterior variance, this should be its square root -- confirm against the paper.
                noise = torch.normal(0.0, 1 / (self.num_data * state['s']))

                p.add_(noise)

    @torch.no_grad()
    def second_step(self):
        r"""Store the noisy gradient, undo the noise, and move the weights to the SAM ascent point."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # Gradient evaluated at the noisy weights; reused for the ascent direction and the ``s`` update.
                state['noisy_gradient'] = p.grad.clone()

                # Undo the noise so the ascent step starts from the saved weights rather than the noisy ones.
                p.copy_(state['old_p'])

                e_w = (
                    (torch.pow(p, 2) if group['adaptive'] else 1.0)
                    * group['rho']
                    * state['noisy_gradient']
                    / state['s']
                )

                p.add_(e_w)

    @torch.no_grad()
    def third_step(self):
        r"""Restore the weights and apply the momentum / precision update."""
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # Return to the saved weights; the ascent perturbation is only used to measure p.grad.
                p.copy_(state['old_p'])

                momentum, s = state['momentum'], state['s']
                # Gradient at the ascent point plus the L2 penalty term (previously the gradient was *scaled* by
                # weight_decay instead).
                momentum.mul_(beta1).add_(p.grad.add(p, alpha=weight_decay), alpha=1.0 - beta1)

                # The precision estimate is driven by the gradient measured at the noisy weights.
                var = (torch.sqrt(s).mul_(state['noisy_gradient'].abs()).add_(weight_decay + self.damping)).pow_(2)
                s.mul_(beta2).add_(var, alpha=1.0 - beta2)

                p.add_(momentum / s, alpha=-group['lr'])

    @torch.no_grad()
    def step(self, closure: Closure = None):
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step()

        with torch.enable_grad():
            closure()

        self.second_step()

        with torch.enable_grad():
            loss = closure()

        self.third_step()

        return loss

CAME

Bases: BaseOptimizer

Confidence-guided Adaptive Memory Efficient Optimization.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.0002
betas Betas

Coefficients (beta1, beta2, beta3) for the running averages of the update, the squared gradient, and the squared residual between the update and its running average.

(0.9, 0.999, 0.9999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
clip_threshold float

Threshold of root-mean-square of final gradient update.

1.0
ams_bound bool

Whether to use the AMSBound variant.

False
eps1 float

Term added to the denominator to improve numerical stability.

1e-30
eps2 float

Term added to the denominator to improve numerical stability.

1e-16
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/came.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class CAME(BaseOptimizer):
    """Confidence-guided Adaptive Memory Efficient Optimization.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients (beta1, beta2, beta3) for the running averages of the update, the squared
            gradient, and the squared residual between the update and its running average.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        clip_threshold (float): Threshold of root-mean-square of final gradient update.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps1 (float): Term added to the squared gradient to improve numerical stability.
        eps2 (float): Term added to the squared residual to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 2e-4,
        betas: Betas = (0.9, 0.999, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        clip_threshold: float = 1.0,
        ams_bound: bool = False,
        eps1: float = 1e-30,
        eps2: float = 1e-16,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.clip_threshold = clip_threshold
        self.eps1 = eps1
        self.eps2 = eps2
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps1': eps1,
            'eps2': eps2,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'CAME'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            grad_shape: Tuple[int, ...] = grad.shape
            factored: bool = self.get_options(grad_shape)

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)

                if factored:
                    # Row / column statistics replace the full second moment for tensors with >= 2 dims.
                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_sq_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                    state['exp_avg_res_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_res_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                if group['ams_bound']:
                    state['exp_avg_sq_hat'] = torch.zeros_like(grad)

                state['RMS'] = 0.0

    @staticmethod
    def get_options(shape: Tuple[int, ...]) -> bool:
        r"""Return whether the second moment is factored (True for shapes with two or more dimensions)."""
        return len(shape) >= 2

    @staticmethod
    def get_rms(x: torch.Tensor) -> torch.Tensor:
        r"""Get the root-mean-square of all elements of `x`."""
        return x.norm(2) / math.sqrt(x.numel())

    @staticmethod
    def approximate_sq_grad(
        exp_avg_sq_row: torch.Tensor,
        exp_avg_sq_col: torch.Tensor,
        output: torch.Tensor,
    ):
        r"""Write the rank-1 inverse-sqrt approximation of the factored second moment into `output`."""
        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2, beta3 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_shape: Tuple[int, ...] = grad.shape
                factored: bool = self.get_options(grad_shape)

                state['RMS'] = self.get_rms(p)

                # Apply weight decay before `grad` is consumed. For coupled (L2) decay, which adjusts `grad`,
                # this is the only way the penalty reaches the update. Decoupled decay acts on `p` and is
                # unaffected by the ordering.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                update = torch.mul(grad, grad).add_(self.eps1)

                if factored:
                    exp_avg_sq_row, exp_avg_sq_col = state['exp_avg_sq_row'], state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2).add_(update.mean(dim=-1), alpha=1.0 - beta2)
                    exp_avg_sq_col.mul_(beta2).add_(update.mean(dim=-2), alpha=1.0 - beta2)

                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2).add_(update, alpha=1.0 - beta2)
                    torch.rsqrt(exp_avg_sq, out=update)

                if group['ams_bound']:
                    exp_avg_sq_hat = state['exp_avg_sq_hat']
                    torch.max(exp_avg_sq_hat, 1 / update, out=exp_avg_sq_hat)
                    torch.rsqrt(exp_avg_sq_hat / beta2, out=update)

                update.mul_(grad)

                # Clip the update so that its RMS does not exceed clip_threshold.
                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0))

                exp_avg = state['exp_avg']
                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

                res = update - exp_avg
                res.pow_(2).add_(self.eps2)

                if factored:
                    exp_avg_res_row, exp_avg_res_col = state['exp_avg_res_row'], state['exp_avg_res_col']

                    exp_avg_res_row.mul_(beta3).add_(res.mean(dim=-1), alpha=1.0 - beta3)
                    exp_avg_res_col.mul_(beta3).add_(res.mean(dim=-2), alpha=1.0 - beta3)

                    self.approximate_sq_grad(exp_avg_res_row, exp_avg_res_col, update)
                    update.mul_(exp_avg)
                else:
                    # Copy: `update` is scaled in place below and must not alias the exp_avg state.
                    update = exp_avg.clone()

                update.mul_(group['lr'])

                p.add_(-update)

        return loss

approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output) staticmethod

Get approximation of EMA of squared gradient.

Source code in pytorch_optimizer/optimizer/came.py
121
122
123
124
125
126
127
128
129
130
@staticmethod
def approximate_sq_grad(
    exp_avg_sq_row: torch.Tensor,
    exp_avg_sq_col: torch.Tensor,
    output: torch.Tensor,
):
    r"""Write the outer product of normalized row and column inverse square roots into `output`.

    The row statistics are divided by their mean over the last dimension before the inverse square root.
    The inputs are left unchanged; only `output` is written.
    """
    row_mean = exp_avg_sq_row.mean(dim=-1, keepdim=True)
    inv_sqrt_row = torch.rsqrt(exp_avg_sq_row / row_mean).unsqueeze(-1)
    inv_sqrt_col = torch.rsqrt(exp_avg_sq_col.unsqueeze(-2))
    torch.mul(inv_sqrt_row, inv_sqrt_col, out=output)

get_options(shape) staticmethod

Return whether the second moment is factored (True for shapes with two or more dimensions).

Source code in pytorch_optimizer/optimizer/came.py
111
112
113
114
@staticmethod
def get_options(shape: Tuple[int, ...]) -> bool:
    r"""Return ``True`` when `shape` has more than one dimension, i.e. the second moment is `factored`."""
    return len(shape) > 1

get_rms(x) staticmethod

Get RMS.

Source code in pytorch_optimizer/optimizer/came.py
116
117
118
119
@staticmethod
def get_rms(x: torch.Tensor) -> torch.Tensor:
    r"""Return the root-mean-square over every element of `x` as a 0-dim tensor."""
    num_elements: int = x.numel()
    return x.norm(p=2).div(math.sqrt(num_elements))

centralize_gradient(grad, gc_conv_only=False)

Gradient Centralization (GC).

Parameters:

Name Type Description Default
grad Tensor

Gradient tensor.

required
gc_conv_only bool

If False, apply GC to both convolutional and fully connected layers; if True, apply only to convolutional layers.

False
Source code in pytorch_optimizer/optimizer/gradient_centralization.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
def centralize_gradient(grad: torch.Tensor, gc_conv_only: bool = False) -> None:
    """Gradient Centralization (GC), applied in place.

    Subtracts from ``grad`` its mean over every dimension except the first.

    Args:
        grad (torch.Tensor): Gradient tensor, modified in place.
        gc_conv_only (bool): If False, centralize any gradient with at least 2 dimensions (convolutional and fully
            connected layers); if True, only gradients with at least 4 dimensions (convolutional layers).

    """
    ndim: int = grad.dim()
    min_ndim: int = 4 if gc_conv_only else 2
    if ndim < min_ndim:
        return

    reduce_dims = tuple(range(1, ndim))
    grad.sub_(grad.mean(dim=reduce_dims, keepdim=True))

Conda

Bases: BaseOptimizer

Column-Normalized Adam for Training Large Language Models Faster.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
update_proj_gap int

Update projection gap.

2000
scale float

Galore projection scaling factor.

1.0
projection_type PROJECTION_TYPE

The type of the projection.

'std'
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/conda.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class Conda(BaseOptimizer):
    """Column-Normalized Adam for Training Large Language Models Faster.

    Adam-style optimizer. For every 2-D parameter, the gradient is projected with a full-rank (``rank=None``)
    ``GaLoreProjector`` before the second moment is accumulated, and the normalized update is projected back
    before it is applied. Parameters of any other rank receive a plain Adam update.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        weight_decay (float): Weight decay, applied after the parameter update with ``weight_decouple=True``.
        update_proj_gap (int): Update projection gap.
        scale (float): GaLore projection scaling factor.
        projection_type (PROJECTION_TYPE): The type of the projection.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        update_proj_gap: int = 2000,
        scale: float = 1.0,
        projection_type: PROJECTION_TYPE = 'std',
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_positive(update_proj_gap, 'update_proj_gap')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'update_proj_gap': update_proj_gap,
            'scale': scale,
            'projection_type': projection_type,
            'eps': eps,
            **kwargs,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Conda'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Set up the group step counter and lazily allocate per-parameter moment buffers.

        Raises:
            NoSparseGradientError: If a parameter has a sparse gradient.
            NoComplexParameterError: If a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            # buffers have the parameter's shape; the projector with rank=None is assumed to keep that shape
            # (TODO confirm in GaLoreProjector)
            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            # Adam bias corrections are folded into the step size; debias() presumably returns 1 - beta ** step
            # (defined in BaseOptimizer)
            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # the first moment is always accumulated from the raw (unprojected) gradient
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                if p.dim() == 2:
                    # one projector per 2-D parameter, created on first use
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=None,
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )

                    # the freshly updated exp_avg is handed to project(); presumably it is used to (re)build the
                    # projection every update_proj_gap steps - TODO confirm in GaLoreProjector
                    grad = state['projector'].project(grad, group['step'], exp_avg)
                    # rebinds the local name only: state['exp_avg'] keeps the unprojected running average
                    exp_avg = state['projector'].project(exp_avg, group['step'])

                # for 2-D parameters the second moment lives in the projected coordinates
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                norm_grad = exp_avg / de_nom

                if p.dim() == 2:
                    # map the normalized update back to the parameter's coordinates
                    norm_grad = state['projector'].project_back(norm_grad)

                p.add_(norm_grad, alpha=-step_size)

                # decay is applied after the update, always in decoupled mode
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

        return loss

DAdaptAdaGrad

Bases: BaseOptimizer

AdaGrad with D-Adaptation. Leave LR set to 1 unless you encounter instability.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

1.0
momentum float

Momentum factor.

0.0
d0 float

Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

Prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

0.0
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
class DAdaptAdaGrad(BaseOptimizer):
    """AdaGrad with D-Adaptation. Leave LR set to 1 unless you encounter instability.

    On every step, each parameter is recomputed from two quantities. One is its initial value ``x0``. The other is
    the running sum ``sk`` of ``d * lr``-scaled gradients, divided by ``sqrt(alpha_k) + eps`` on the dense path.
    ``alpha_k`` accumulates squared gradients. The distance estimate ``d`` is shared by all param groups and never
    decreases. Sparse gradients are supported.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float): Momentum factor. If positive, the parameter is blended towards the new iterate instead of
            being overwritten (dense gradients only).
        d0 (float): Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float): Prevent the D estimate from growing faster than this multiplicative rate.
        weight_decay (float): Weight decay (L2 penalty). Only applied to parameters with dense gradients.
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        momentum: float = 0.0,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'd': d0,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'k': 0,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdaGrad'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Set up the group step counter and lazily allocate per-parameter state.

        Unlike most optimizers here, sparse gradients are accepted; only complex parameters are rejected.

        Raises:
            NoComplexParameterError: If a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'alpha_k' not in state:
                # alpha_k starts slightly above zero; presumably to keep sqrt(alpha_k) non-zero (eps defaults to 0)
                state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
                state['sk'] = torch.zeros_like(p)
                # the iterate is always recomputed relative to this snapshot of the initial parameters
                state['x0'] = torch.clone(p)
                if p.grad.is_sparse:
                    state['weighted_sk'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:  # noqa: PLR0912, PLR0915
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # d, lr and the global D-adaptation accumulators are read from the first param group
        group = self.param_groups[0]
        device = group['params'][0].device

        d, lr = group['d'], group['lr']
        d_lr: float = d * lr

        # per-step deltas; folded into the persistent accumulators after the first pass
        g_sq = torch.tensor([0.0], device=device)
        sk_sq_weighted_change = torch.tensor([0.0], device=device)
        sk_l1_change = torch.tensor([0.0], device=device)
        if 'gsq_weighted' not in group:
            group['gsq_weighted'] = torch.tensor([0.0], device=device)
        if 'sk_sq_weighted' not in group:
            group['sk_sq_weighted'] = torch.tensor([0.0], device=device)
        if 'sk_l1' not in group:
            group['sk_l1'] = torch.tensor([0.0], device=device)

        gsq_weighted = group['gsq_weighted']
        sk_sq_weighted = group['sk_sq_weighted']
        sk_l1 = group['sk_l1']

        # first pass: update alpha_k / sk and accumulate the statistics used to estimate d
        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                sk, alpha_k = state['sk'], state['alpha_k']

                if grad.is_sparse:
                    # sparse path: only the entries present in grad are touched (via sparse_mask) and no weight
                    # decay is applied
                    weighted_sk = state['weighted_sk']

                    grad = grad.coalesce()

                    vk = grad._values().pow(2)
                    sk_masked = sk.sparse_mask(grad).coalesce()
                    old_sk_l1_masked = sk_masked._values().abs().sum()

                    sk.add_(grad, alpha=d_lr)

                    sk_masked = sk.sparse_mask(grad).coalesce()
                    alpha_k_masked = alpha_k.sparse_mask(grad).coalesce()
                    weighted_sk_masked = weighted_sk.sparse_mask(grad).coalesce()

                    # update alpha before step
                    alpha_k_p1_masked = alpha_k_masked._values() + vk

                    alpha_k_delta_masked = alpha_k_p1_masked - alpha_k_masked._values()
                    alpha_k_delta = torch.sparse_coo_tensor(grad.indices(), alpha_k_delta_masked, grad.shape)
                    alpha_k.add_(alpha_k_delta)

                    # note: eps is added inside the sqrt here, whereas the dense path adds it outside
                    de_nom = torch.sqrt(alpha_k_p1_masked + eps)

                    grad_sq = vk.div(de_nom).sum()
                    g_sq.add_(grad_sq)

                    # update weighted sk sq tracking
                    weighted_sk_p1_masked = sk_masked._values().pow(2).div(de_nom)

                    sk_sq_weighted_change.add_(weighted_sk_p1_masked.sum() - weighted_sk_masked._values().sum())

                    weighted_sk_p1_delta_masked = weighted_sk_p1_masked - weighted_sk_masked._values()
                    weighted_sk_p1_delta = torch.sparse_coo_tensor(
                        grad.indices(), weighted_sk_p1_delta_masked, grad.shape
                    )
                    weighted_sk.add_(weighted_sk_p1_delta)

                    sk_l1_masked = sk_masked._values().abs().sum()
                    sk_l1_change.add_(sk_l1_masked - old_sk_l1_masked)
                else:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                    # the weighted ||sk||^2 and ||sk||_1 contributions are tracked as (new - old) deltas
                    old_sk_sq_weighted_param = sk.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    old_sk_l1_param = sk.abs().sum()

                    alpha_k.add_(grad.pow(2))
                    grad_sq = grad.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    g_sq.add_(grad_sq)

                    sk.add_(grad, alpha=d_lr)

                    sk_sq_weighted_param = sk.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    sk_l1_param = sk.abs().sum()

                    sk_sq_weighted_change.add_(sk_sq_weighted_param - old_sk_sq_weighted_param)
                    sk_l1_change.add_(sk_l1_param - old_sk_l1_param)

        sk_sq_weighted.add_(sk_sq_weighted_change)
        gsq_weighted.add_(g_sq, alpha=d_lr ** 2)  # fmt: skip
        sk_l1.add_(sk_l1_change)

        # no accumulated gradient yet: leave the parameters untouched
        if sk_l1 == 0:
            return loss

        # d can only grow; growth per step is capped by growth_rate.
        # NOTE(review): `group` here is the last param group left over from the loop above.
        if lr > 0.0:
            d_hat = (sk_sq_weighted - gsq_weighted) / sk_l1
            d = group['d'] = max(d, min(d_hat.item(), d * group['growth_rate']))

        # second pass: share the accumulators / d across groups and recompute every iterate from x0 and sk
        for group in self.param_groups:
            group['gsq_weighted'] = gsq_weighted
            group['sk_sq_weighted'] = sk_sq_weighted
            group['sk_l1'] = sk_l1
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]

                alpha_k, sk, x0 = state['alpha_k'], state['sk'], state['x0']

                if grad.is_sparse:
                    grad = grad.coalesce()

                    sk_masked = sk.sparse_mask(grad).coalesce()._values()
                    alpha_k_masked = alpha_k.sparse_mask(grad).coalesce()._values()
                    x0_masked = x0.sparse_mask(grad).coalesce()._values()
                    p_masked = p.sparse_mask(grad).coalesce()._values()

                    loc_masked = x0_masked - sk_masked.div(torch.sqrt(alpha_k_masked + group['eps']))

                    # apply the change only on the masked entries (momentum is not used on this path)
                    loc_delta_masked = loc_masked - p_masked
                    loc_delta = torch.sparse_coo_tensor(grad.indices(), loc_delta_masked, grad.shape)
                    p.add_(loc_delta)
                else:
                    z = x0 - sk.div(alpha_k.sqrt().add_(group['eps']))

                    if group['momentum'] > 0.0:
                        p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])
                    else:
                        p.copy_(z)

            group['k'] += 1

        return loss

DAdaptAdam

Bases: BaseOptimizer

Adam with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

1.0
betas Betas

Coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
d0 float

Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

Prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Use AdamW style weight decay.

False
fixed_decay bool

Fix weight decay.

False
bias_correction bool

Turn on Adam's bias correction.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
class DAdaptAdam(BaseOptimizer):
    """Adam with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        d0 (float): Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float): Prevent the D estimate from growing faster than this multiplicative rate.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Use AdamW style weight decay.
        fixed_decay (bool): Fix weight decay.
        bias_correction (bool): Turn on Adam's bias correction.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        betas: Betas = (0.9, 0.999),
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        bias_correction: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'd': d0,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'bias_correction': bias_correction,
            'step': 0,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdam'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily allocate per-parameter state (``s``, ``exp_avg``, ``exp_avg_sq``).

        Only parameters without state are initialized, so calling this every step is safe.

        Raises:
            NoSparseGradientError: If a parameter has a sparse gradient.
            NoComplexParameterError: If a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # D-adaptation keeps one global estimate; its hyper-parameters and accumulator come from the first group
        group = self.param_groups[0]
        device = group['params'][0].device

        beta1, beta2 = group['betas']

        beta2_sq: float = math.sqrt(beta2)

        d: float = group['d']
        lr: float = group['lr']

        bias_correction1: float = 1.0 - beta1 ** (group['step'] + 1)
        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] + 1))
        bias_correction: float = bias_correction1 / bias_correction2_sq

        # apply_adam_debias presumably returns d * lr unchanged when bias correction is off and divides it by
        # `bias_correction` otherwise (defined in BaseOptimizer)
        d_lr: float = self.apply_adam_debias(
            not group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
        )

        sk_l1 = torch.tensor([0.0], device=device)
        numerator_acc = torch.tensor([0.0], device=device)

        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        # first pass: update the moments and accumulate the D-adaptation numerator / ||s||_1
        for group in self.param_groups:
            # init_group only allocates missing state, so parameters whose first gradient arrives after step 0
            # are initialized too (previously they raised a KeyError)
            self.init_group(group)

            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, s = state['exp_avg'], state['exp_avg_sq'], state['s']

                # numerator uses the denominator from *before* this step's second-moment update
                de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                numerator_acc.add_(torch.dot(grad.flatten(), s.div(de_nom).flatten()), alpha=d_lr)

                exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                s.mul_(beta2_sq).add_(grad, alpha=d_lr * (1.0 - beta2_sq))

                sk_l1.add_(s.abs().sum())

        # no accumulated gradient yet: leave the parameters untouched
        if sk_l1 == 0:
            return loss

        numerator_weighted.mul_(beta2_sq).add_(numerator_acc, alpha=1.0 - beta2_sq)  # fmt: skip

        # d can only grow; growth per step is capped by growth_rate.
        # NOTE(review): `group` here is the last param group left over from the loop above.
        if lr > 0.0:
            # the weighted numerator is divided by the *product* (1 - sqrt(beta2)) * ||s||_1 (D-Adapt V3);
            # without the parentheses, ||s||_1 ended up multiplying the estimate instead of dividing it
            d_hat = numerator_weighted / ((1.0 - beta2_sq) * sk_l1)
            d = max(d, min(d_hat.item(), d * group['growth_rate']))

        # second pass: share the accumulator / d across groups and apply the Adam step
        for group in self.param_groups:
            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                # NOTE(review): grad=None is passed, so coupled (weight_decouple=False) decay presumably has no
                # effect here - confirm against BaseOptimizer.apply_weight_decay
                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # exp_avg already carries the d_lr scaling, hence the unit step
                p.addcdiv_(exp_avg, de_nom, value=-1.0)

        return loss

DAdaptAdan

Bases: BaseOptimizer

Adan with D-Adaptation. Leave LR set to 1 unless you encounter instability.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

1.0
betas Betas

Coefficients (beta1, beta2, beta3) used for the running averages of the gradient, the gradient difference, and the squared update.

(0.98, 0.92, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Decoupled weight decay.

False
d0 float

Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

Prevent the D estimate from growing faster than this multiplicative rate. Default is inf, for unrestricted.

float('inf')
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
class DAdaptAdan(BaseOptimizer):
    """Adan with D-Adaptation. Leave LR set to 1 unless you encounter instability.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients (beta1, beta2, beta3). They drive the running averages of the gradient, of the
            gradient difference, and of the squared update. beta3 also drives the D-adaptation statistics.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Decoupled weight decay. If True, parameters are shrunk multiplicatively before the
            update; otherwise they are divided by ``1 + d * lr * weight_decay`` after it.
        d0 (float): Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float): Prevent the D estimate from growing faster than this multiplicative rate.
            Default is inf, for unrestricted.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        betas: Betas = (0.98, 0.92, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'd': d0,
            'growth_rate': growth_rate,
            'k': 0,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdan'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Set up the group step counter and lazily allocate per-parameter state.

        Only parameters without state are initialized, so calling this every step is safe.

        Raises:
            NoSparseGradientError: If a parameter has a sparse gradient.
            NoComplexParameterError: If a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'exp_avg' not in state:
                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_diff'] = torch.zeros_like(p)
                # step() adds the current (possibly negated) gradient to previous_grad to form g_t - g_{t-1}.
                # This runs before step() negates the gradient for maximize=True, so the sign is pre-compensated
                # to make the first difference exactly zero in both modes.
                state['previous_grad'] = grad.clone() if self.maximize else -grad.clone()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # D-adaptation keeps one global estimate; its hyper-parameters and accumulator come from the first group
        group = self.param_groups[0]
        device = group['params'][0].device

        beta1, beta2, beta3 = group['betas']
        growth_rate = group['growth_rate']

        d, lr = group['d'], group['lr']
        d_lr = float(d * lr)

        g_sq = torch.tensor([0.0], device=device)
        sk_sq_weighted = torch.tensor([0.0], device=device)
        sk_l1 = torch.tensor([0.0], device=device)
        if 'gsq_weighted' not in group:
            group['gsq_weighted'] = torch.tensor([0.0], device=device)
        gsq_weighted = group['gsq_weighted']

        # first pass: update the Adan moments and accumulate the D-adaptation statistics
        for group in self.param_groups:
            # init_group only allocates missing state, so parameters whose first gradient arrives late are covered
            self.init_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # previous_grad holds -g_{t-1}; adding g_t in place yields the gradient difference
                grad_diff = state['previous_grad']
                grad_diff.add_(grad)

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=d_lr * (1.0 - beta2))

                # Adan's second moment is an EMA of the *squared* update g_t + beta2 * (g_t - g_{t-1}); the
                # previous addcmul_ on an already-squared tensor accumulated its fourth power instead
                update = grad_diff.mul_(beta2).add_(grad)
                exp_avg_sq.mul_(beta3).add_(to_real(update * update.conj()), alpha=1.0 - beta3)

                grad_power = to_real(grad * grad.conj())
                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                g_sq.add_(grad_power.div_(de_nom).sum())

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=d_lr * (1.0 - beta3))

                sk_sq_weighted.add_(to_real(s * s.conj()).div_(de_nom).sum())
                sk_l1.add_(s.abs().sum())

                # stored negated so the next step's in-place add_ yields g_{t+1} - g_t
                state['previous_grad'].copy_(-grad)

        # no accumulated gradient yet: leave the parameters untouched
        if sk_l1 == 0:
            return loss

        gsq_weighted.mul_(beta3).add_(g_sq, alpha=(d_lr ** 2) * (1.0 - beta3))  # fmt: skip

        # d can only grow; growth per step is capped by growth_rate. `.item()` keeps d a Python float, as in the
        # other D-Adapt optimizers, instead of silently turning group['d'] into a tensor.
        if lr > 0.0:
            d_hat = (sk_sq_weighted / (1.0 - beta3) - gsq_weighted) / sk_l1
            d = max(d, min(d_hat.item(), d * growth_rate))

        # second pass: share the accumulator / d across groups and apply the Adan step with proximal weight decay
        for group in self.param_groups:
            group['step'] += 1

            group['gsq_weighted'] = gsq_weighted
            group['d'] = d
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if group['weight_decouple']:
                    p.mul_(1.0 - d_lr * group['weight_decay'])

                p.addcdiv_(exp_avg, de_nom, value=-1.0)
                p.addcdiv_(exp_avg_diff, de_nom, value=-beta2)

                if not group['weight_decouple']:
                    p.div_(1.0 + d_lr * group['weight_decay'])

            group['k'] += 1

        return loss

DAdaptLion

Bases: BaseOptimizer

Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

1.0
betas Betas

(Betas). Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
d0 float

Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
class DAdaptLion(BaseOptimizer):
    """Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients ``(beta1, beta2)``. ``beta1`` interpolates the momentum buffer with the current
            gradient to form the sign update; ``beta2`` is the decay rate of the momentum buffer, and ``sqrt(beta2)``
            is the decay rate of the D-adaptation accumulators.
        d0 (float): Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        betas: Betas = (0.9, 0.999),
        d0: float = 1e-6,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'd': d0,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'step': 0,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptLion'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily create the ``exp_avg`` and ``s`` buffers.

        Safe to call on every step: parameters that already have state are left untouched.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The D-adaptation statistics are shared by all groups and read from the first one.
        group = self.param_groups[0]
        device = group['params'][0].device

        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        sk_l1 = torch.tensor([0.0], device=device)
        numerator_accumulator = torch.tensor([0.0], device=device)

        beta1, beta2 = group['betas']
        beta2_sq = math.sqrt(beta2)

        d, lr = group['d'], group['lr']
        d_lr: float = d * lr

        for group in self.param_groups:
            # Run on every step, not only step 0. Otherwise a parameter whose gradient first appears later would
            # have no ``exp_avg``/``s`` state and the lookup below would raise ``KeyError``.
            self.init_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, s = state['exp_avg'], state['s']

                # Lion update: sign of the beta1-interpolation between momentum and the current gradient.
                update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
                p.add_(update, alpha=-d_lr)

                exp_avg.mul_(beta2).add_(grad, alpha=(1.0 - beta2) * d_lr)

                # D-adaptation accumulators use sqrt(beta2) as their decay rate.
                numerator_accumulator.add_(torch.dot(update.flatten(), s.flatten()), alpha=d_lr)
                s.mul_(beta2_sq).add_(update, alpha=(1.0 - beta2_sq) * d_lr)

                sk_l1.add_(s.abs().sum())

        numerator_weighted.mul_(beta2_sq).add_(numerator_accumulator, alpha=1.0 - beta2_sq)

        # No accumulated update at all: leave D and the step counter unchanged.
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            d_hat: float = (numerator_weighted / ((1.0 - beta2_sq) * sk_l1)).item()
            d = max(d, d_hat)

        for group in self.param_groups:
            group['step'] += 1

            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

        return loss

DAdaptSGD

Bases: BaseOptimizer

SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

1.0
momentum float

Momentum.

0.9
d0 float

Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

Prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
class DAdaptSGD(BaseOptimizer):
    """SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float): Momentum.
        d0 (float): Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float): Prevent the D estimate from growing faster than this multiplicative rate.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        momentum: float = 0.9,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'd': d0,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'step': 0,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptSGD'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily create ``z``, ``s`` and ``x0`` for parameters that have a gradient.

        Safe to call on every step: parameters that already have state are left untouched.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = p.clone()
                state['s'] = torch.zeros_like(p)
                state['x0'] = p.clone()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The D-adaptation statistics are shared by all groups and read from the first one.
        group = self.param_groups[0]
        device = group['params'][0].device

        sk_sq = torch.tensor([0.0], device=device)
        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        # The gradient magnitude measured on the first step normalises every later step size.
        if group['step'] == 0:
            group['g0_norm'] = get_global_gradient_norm(self.param_groups).sqrt_().item()
        g0_norm = group['g0_norm']

        # Nothing to normalise by: skip without advancing the step counter.
        if g0_norm == 0:
            return loss

        d, lr = group['d'], group['lr']
        d_lr: float = d * lr / g0_norm

        for group in self.param_groups:
            # Run on every step so parameters whose gradient first appears after step 0 also get state.
            # Otherwise ``state['s']`` / ``state['z']`` below raise ``KeyError``.
            self.init_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # Pass the gradient so the coupled (L2) variant has something to add the penalty to. With
                # ``grad=None`` the default ``weight_decouple=False`` configuration dropped ``weight_decay``.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s = state['s']
                numerator_weighted.add_(torch.dot(grad.flatten(), s.flatten()), alpha=d_lr)

                s.add_(grad, alpha=d_lr)
                sk_sq.add_(s.pow(2).sum())

        if lr > 0.0:
            d_hat = 2.0 * numerator_weighted / sk_sq.sqrt()
            d = max(d, min(d_hat.item(), d * group['growth_rate']))

        for group in self.param_groups:
            group['step'] += 1

            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # Dual-averaging iterate z = x0 - s, then x <- momentum * x + (1 - momentum) * z.
                z = state['z']
                z.copy_(state['x0'] - state['s'])

                p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])

        return loss

DeMo

Bases: SGD, BaseOptimizer

Decoupled Momentum Optimization.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
compression_decay float

Compression decay.

0.999
compression_top_k int

Compression top-k.

32
compression_chunk int

Compression chunk size.

64
weight_decay float

Decoupled weight decay, applied directly to the parameters (as in AdamW).

0.0
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/demo.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
class DeMo(torch.optim.SGD, BaseOptimizer):  # pragma: no cover
    """Decoupled Momentum Optimization.

    Each step adds ``lr * grad`` into a per-parameter residual ``delta`` and compresses it (DCT transform + top-k).
    The transmitted part is removed from the residual, and the compressed components are all-gathered across ranks.
    They are decoded into a new gradient, whose sign is then applied by a plain ``torch.optim.SGD`` step.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        compression_decay (float): Decay factor applied to the residual ``delta`` every step.
        compression_top_k (int): Top-k value passed to the compressor.
        compression_chunk (int): Chunk size used by the DCT transform.
        weight_decay (float): Decoupled weight decay, applied directly to the parameters.
        process_group (Optional[ProcessGroup]): Process group used for the all-gather. The default group is used
            when None.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        compression_decay: float = 0.999,
        compression_top_k: int = 32,
        compression_chunk: int = 64,
        weight_decay: float = 0.0,
        process_group: Optional[ProcessGroup] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(compression_decay, 'compression_decay', 0.0, 1.0, range_type='[)')
        self.validate_positive(compression_top_k, 'compression_top_k')
        self.validate_positive(compression_chunk, 'compression_chunk')

        self.weight_decay = weight_decay

        self.compression_decay = compression_decay
        self.compression_top_k = compression_top_k
        self.compression_chunk = compression_chunk
        self.process_group = process_group

        self.data_transmit: int = 0
        self.data_receive: int = 0

        self.maximize = maximize

        # The inner SGD only performs ``p -= lr * sign(grad)``; momentum, weight decay and maximize are handled by
        # DeMo itself.
        super().__init__(
            params,
            lr=lr,
            foreach=False,
            momentum=0.0,
            dampening=0.0,
            nesterov=False,
            maximize=False,
            weight_decay=0.0,
            **kwargs,
        )

        self.demo_state = {}
        self.init_demo_states()
        self.init_parameters()

        self.default_dtype: torch.dtype = self.find_dtype()
        self.transform = TransformDCT(self.param_groups, self.compression_chunk, norm='ortho')
        self.compress = CompressDCT()

    def __str__(self) -> str:
        return 'DeMo'

    def find_dtype(self) -> torch.dtype:
        r"""Return the dtype of the first parameter that requires grad, or ``torch.float32`` if there is none."""
        for group in self.param_groups:
            for p in group['params']:
                if p.requires_grad:
                    return p.dtype
        return torch.float32

    def init_demo_states(self) -> None:
        """Create an empty DeMo state dict for every trainable parameter."""
        for group in self.param_groups:
            for p in group['params']:
                if p.requires_grad:
                    self.demo_state[p] = {}

    def init_parameters(self) -> None:
        """Reset the step counters and zero the ``delta`` residual of every trainable parameter."""
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.demo_state.get(p, {})

                    state['delta'] = torch.zeros_like(p)

    def demo_all_gather(self, sparse_idx, sparse_val):
        """All-gather the compressed indices and values from every rank of ``self.process_group``."""
        world_size: int = get_world_size() if self.process_group is None else self.process_group.size()

        sparse_idx_list = [torch.zeros_like(sparse_idx) for _ in range(world_size)]
        sparse_val_list = [torch.zeros_like(sparse_val) for _ in range(world_size)]

        sparse_idx_handle = all_gather(sparse_idx_list, sparse_idx, group=self.process_group, async_op=True)
        sparse_val_handle = all_gather(sparse_val_list, sparse_val, group=self.process_group, async_op=True)

        sparse_idx_handle.wait()
        sparse_val_handle.wait()

        return sparse_idx_list, sparse_val_list

    @torch.no_grad()
    def init_group(self):
        pass

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        # Evaluate the closure *before* building the DeMo gradients. Forwarding it to ``SGD.step`` would re-run
        # backward after the sign gradients below are written into ``p.grad``, overwriting or summing with them.
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.data_transmit = 0
        self.data_receive = 0

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                # The inner SGD is built with maximize=False, so ``maximize`` must be honoured here.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.demo_state.get(p, {})

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=lr,
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                if self.compression_decay != 1:
                    state['delta'].mul_(self.compression_decay)

                state['delta'].add_(grad, alpha=lr)

                sparse_idx, sparse_val, x_shape = self.compress.compress(
                    self.transform.encode(state['delta']), self.compression_top_k
                )

                transmit_grad = self.transform.decode(self.compress.decompress(p, sparse_idx, sparse_val, x_shape))

                # Keep only the part of the residual that was not transmitted this step.
                state['delta'].sub_(transmit_grad)

                sparse_idx_gather, sparse_val_gather = self.demo_all_gather(sparse_idx, sparse_val)

                self.data_transmit += sparse_idx.nbytes + sparse_val.nbytes
                for si, v in zip(sparse_idx_gather, sparse_val_gather):
                    self.data_receive += si.nbytes + v.nbytes

                new_grad = self.transform.decode(
                    self.compress.batch_decompress(p, sparse_idx_gather, sparse_val_gather, x_shape)
                )

                # ``p.grad`` is guaranteed non-None here (see the ``continue`` above).
                p.grad.copy_(new_grad)
                p.grad.sign_()

        super().step()

        return loss

find_dtype()

Return the dtype of the first parameter that requires grad, or torch.float32 if there is none.

Source code in pytorch_optimizer/optimizer/demo.py
358
359
360
361
362
363
364
def find_dtype(self) -> torch.dtype:
    r"""Return the dtype of the first parameter that requires grad, falling back to ``torch.float32``."""
    trainable = (param for grp in self.param_groups for param in grp['params'] if param.requires_grad)
    first_trainable = next(trainable, None)
    if first_trainable is None:
        return torch.float32
    return first_trainable.dtype

DiffGrad

Bases: BaseOptimizer

An Optimization Method for Convolutional Neural Networks.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
rectify bool

Perform the rectified update similar to RAdam.

False
n_sma_threshold int

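SMA-length threshold above which the rectified adaptive update is used (only when rectify is True). Recommended is 5.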

5
degenerated_to_sgd bool

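Fall back to an SGD-with-momentum update while the rectified adaptive update is not yet applicable (only when rectify is True).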

True
ams_bound bool

Whether to use the AMSBound variant.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/diffgrad.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class DiffGrad(BaseOptimizer):
    """An Optimization Method for Convolutional Neural Networks.

    diffGrad scales the first moment by a friction coefficient ``sigmoid(|g_{t-1} - g_t|)``. Steps shrink where the
    gradient changes little between iterations.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        rectify (bool): Perform the rectified update similar to RAdam.
        n_sma_threshold (int): SMA-length threshold above which the rectified adaptive update is used
            (only with ``rectify``). Recommended is 5.
        degenerated_to_sgd (bool): Degenerated to SGD.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'diffGrad'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily create the moment / previous-gradient buffers for each parameter.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['previous_grad'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            # Needed by the non-rectified path; the rectification factor already includes it otherwise.
            bias_correction2_sq: float = self.debias(beta2, group['step']) ** 0.5

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                # Apply weight decay on every path. It used to run only when ``rectify`` was True, so the default
                # configuration ignored ``weight_decay``. Passing ``grad`` also lets the coupled (L2) variant work.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                state = self.state[p]

                exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']

                p, grad, exp_avg, exp_avg_sq, previous_grad = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, previous_grad
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                # diffGrad friction coefficient: sigmoid(|g_{t-1} - g_t|) * m_t.
                dfc = previous_grad.clone()
                dfc.sub_(grad).abs_().sigmoid_().mul_(exp_avg)
                state['previous_grad'].copy_(
                    torch.view_as_complex(grad) if torch.is_complex(state['previous_grad']) else grad
                )

                if not group['rectify']:
                    # Use the friction-scaled moment (otherwise this is plain Adam) with the second-moment bias
                    # correction applied to the denominator.
                    p.addcdiv_(dfc, de_nom / bias_correction2_sq, value=-step_size)
                    continue

                if n_sma >= self.n_sma_threshold:
                    p.addcdiv_(dfc, de_nom, value=-step_size)
                elif step_size > 0:
                    p.add_(exp_avg, alpha=-step_size)

        return loss

DistributedMuon

Bases: BaseOptimizer

Momentum Orthogonalized by Newton-schulz.

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and scalar or vector parameters should be optimized using AdamW.

Some warnings:
- We believe this optimizer is unlikely to work well for training with small batch size.
- We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

Parameters:

Name Type Description Default
params ParamsT

The parameters to be optimized by Muon.

required
lr float

Learning rate.

0.02
momentum float

The momentum used by the internal SGD.

0.95
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
nesterov bool

Whether to use nesterov momentum.

True
ns_steps int

The number of Newton-Schulz iterations to run. (5 is probably always enough)

5
ns_coeffs NewtonSchulzWeights

Newton-Schulz coefficients or preset name.

'original'
use_adjusted_lr bool

Whether to use adjusted learning rate, which is from the Moonlight. Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py

False
adamw_lr float

The learning rate for the internal AdamW.

0.0003
adamw_betas tuple

The betas for the internal AdamW.

(0.9, 0.95)
adamw_wd float

The weight decay for the internal AdamW.

0.0
adamw_eps float

The epsilon for the internal AdamW.

1e-10
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Example

from pytorch_optimizer import DistributedMuon

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

param_groups = [
    dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
    dict(
        params=hidden_gains_biases + non_hidden_params,
        lr=3e-4,
        betas=(0.9, 0.95),
        weight_decay=0.01,
        use_muon=False,
    ),
]

optimizer = DistributedMuon(param_groups)

Source code in pytorch_optimizer/optimizer/muon.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
class DistributedMuon(BaseOptimizer):  # pragma: no cover
    """Momentum Orthogonalized by Newton-Schulz, with Muon updates sharded across distributed ranks.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
    each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
    update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

    In a Muon group (``use_muon=True``) the parameters are walked in chunks of ``world_size``: each rank updates the
    parameter at offset ``rank`` inside the chunk, then ``all_gather`` collects the updated tensors so every rank ends
    up with the whole chunk. Groups with ``use_muon=False`` are updated locally with AdamW on every rank.

    Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and
    scalar or vector parameters should be optimized using AdamW.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

    Args:
        params (ParamsT): The parameter groups to be optimized. Every group dict must define ``use_muon``.
        lr (float): Learning rate.
        momentum (float): The momentum used by the internal SGD.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        nesterov (bool): Whether to use nesterov momentum.
        ns_steps (int): The number of Newton-Schulz iterations to run. (5 is probably always enough)
        ns_coeffs (NewtonSchulzWeights): Newton-Schulz coefficients or preset name.
        use_adjusted_lr (bool): Whether to use adjusted learning rate, which is from the Moonlight.
            Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
        adamw_lr (float): The learning rate for the internal AdamW.
        adamw_betas (tuple): The betas for the internal AdamW.
        adamw_wd (float): The weight decay for the internal AdamW.
        adamw_eps (float): The epsilon for the internal AdamW.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    Raises:
        ValueError: If a parameter group does not define ``use_muon``.

    Example:
        from pytorch_optimizer import DistributedMuon

        hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
        hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
        non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

        param_groups = [
            dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
            dict(
                params=hidden_gains_biases + non_hidden_params,
                lr=3e-4,
                betas=(0.9, 0.95),
                weight_decay=0.01,
                use_muon=False,
            ),
        ]

        optimizer = DistributedMuon(param_groups)

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 2e-2,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        nesterov: bool = True,
        ns_steps: int = 5,
        ns_coeffs: NewtonSchulzWeights = 'original',
        use_adjusted_lr: bool = False,
        adamw_lr: float = 3e-4,
        adamw_betas: Betas = (0.9, 0.95),
        adamw_wd: float = 0.0,
        adamw_eps: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(adamw_lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_positive(ns_steps, 'ns_steps')
        self.validate_betas(adamw_betas)
        self.validate_non_negative(adamw_wd, 'adamw_wd')
        self.validate_non_negative(adamw_eps, 'adamw_eps')
        ns_coeffs = get_newton_schulz_weights(ns_coeffs)

        self.maximize = maximize

        self.world_size: int = get_world_size()
        self.rank: int = get_rank()

        # Materialize once: the groups are filled in below and then handed to ``super().__init__``, so a one-shot
        # iterable (e.g. a generator) must not be consumed by the loop.
        params = list(params)

        for group in params:
            group = cast(ParamGroup, group)
            if 'use_muon' not in group:
                raise ValueError('`use_muon` must be set.')

            if group['use_muon']:
                group['lr'] = group.get('lr', lr)
                group['momentum'] = group.get('momentum', momentum)
                group['nesterov'] = group.get('nesterov', nesterov)
                group['weight_decay'] = group.get('weight_decay', weight_decay)
                group['ns_steps'] = group.get('ns_steps', ns_steps)
                group['ns_coeffs'] = get_newton_schulz_weights(group.get('ns_coeffs', ns_coeffs))
                group['use_adjusted_lr'] = group.get('use_adjusted_lr', use_adjusted_lr)
            else:
                group['lr'] = group.get('lr', adamw_lr)
                group['betas'] = group.get('betas', adamw_betas)
                group['eps'] = group.get('eps', adamw_eps)
                group['weight_decay'] = group.get('weight_decay', adamw_wd)

            group['weight_decouple'] = group.get('weight_decouple', weight_decouple)

        super().__init__(params, kwargs)

    def __str__(self) -> str:
        return 'DistributedMuon'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            # Every rank must take part in every all_gather, so missing gradients are zero-filled, not skipped.
            if p.grad is None:
                p.grad = torch.zeros_like(p)

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0 and not group['use_muon']:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    def _muon_step(self, group: ParamGroup) -> None:
        r"""Apply the rank-sharded Muon update to a ``use_muon=True`` group and gather the results on every rank."""
        params = group['params']
        if not params:
            return

        # Pad to a multiple of world_size so each rank always has a (possibly dummy) tensor to contribute.
        padded_params = params + [torch.empty_like(params[-1])] * (
            self.world_size - len(params) % self.world_size
        )

        for i in range(0, len(params), self.world_size):
            if i + self.rank < len(params):
                p = params[i + self.rank]

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]
                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                buf = state['momentum_buffer']
                buf.lerp_(grad, weight=1.0 - group['momentum'])

                update = grad.lerp_(buf, weight=group['momentum']) if group['nesterov'] else buf
                if update.ndim > 2:
                    # Newton-Schulz works on matrices: flatten >2D tensors (e.g. conv kernels) to 2D.
                    update = update.view(len(update), -1)

                update = zero_power_via_newton_schulz_5(
                    update, num_steps=group['ns_steps'], weights=group['ns_coeffs']
                )

                # Restore the parameter's shape before masking so the cautious mask aligns element-wise with grad.
                update = update.reshape(p.shape)

                if group.get('cautious'):
                    self.apply_cautious(update, grad)

                lr: float = get_adjusted_lr(group['lr'], p.size(), use_adjusted_lr=group['use_adjusted_lr'])

                p.add_(update, alpha=-lr)

            # all_gather(output_list, tensor): this rank contributes the single tensor at its own offset in the chunk.
            all_gather(padded_params[i:i + self.world_size], padded_params[i + self.rank])  # fmt: skip

    def _adamw_step(self, group: ParamGroup) -> None:
        r"""Apply a local AdamW update to a ``use_muon=False`` group."""
        beta1, beta2 = group['betas']

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

        for p in group['params']:
            grad = p.grad

            # Honour ``maximize`` and the group's weight decay, exactly as the Muon branch does.
            self.maximize_gradient(grad, maximize=self.maximize)

            self.apply_weight_decay(
                p,
                grad=grad,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=False,
            )

            state = self.state[p]
            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            exp_avg.lerp_(grad, weight=1.0 - beta1)
            exp_avg_sq.lerp_(grad.square(), weight=1.0 - beta2)

            # Standard Adam denominator: bias-correct sqrt(v) first, then add eps (eps is not rescaled).
            de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

            p.addcdiv_(exp_avg / bias_correction1, de_nom, value=-group['lr'])

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if group['use_muon']:
                self._muon_step(group)
            else:
                self._adamw_step(group)

        return loss

DualAdam

Bases: BaseOptimizer

Combining Adam and its inverse counterpart to enhance generalization.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and squared gradient.

(0.9, 0.999)
switch_rate float

Linear decay rate for the inverse Adam update contribution.

0.01
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

False
fixed_decay bool

Whether to fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dual_adam.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class DualAdam(BaseOptimizer):
    """Combining Adam and its inverse counterpart to enhance generalization.

    At step ``t`` the inverse-Adam weight is ``w = max(0, 1 - t * switch_rate)``. While ``w >= switch_rate`` the
    bias-corrected first moment is scaled by ``(1 - w) / denom + w * denom`` (a blend of the Adam and inverse-Adam
    scalings, with ``denom = sqrt(v_hat) + eps``); after that the optimizer takes the plain Adam step
    ``m_hat / denom``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of gradient and squared gradient.
        switch_rate (float): Linear decay rate for the inverse Adam update contribution.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Whether to fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        switch_rate: float = 1e-2,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(switch_rate, 'switch_rate', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            switch_rate=switch_rate,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
            **kwargs,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DualAdam'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if not state:
                state['exp_avg'] = torch.zeros_like(param)
                state['exp_avg_sq'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            step_count: int = group['step']
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, step_count)
            bias_correction2: float = self.debias(beta2, step_count)

            # Weight of the inverse-Adam term: decays linearly from 1 towards 0.
            inverse_weight: float = max(0.0, 1.0 - step_count * group['switch_rate'])
            blend_inverse: bool = inverse_weight >= group['switch_rate']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]
                param, grad, exp_avg, exp_avg_sq = self.view_as_real(
                    param, grad, state['exp_avg'], state['exp_avg_sq']
                )

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                exp_avg_hat = exp_avg.div(bias_correction1)
                de_nom = exp_avg_sq.div(bias_correction2).sqrt_().add_(group['eps'])

                if not blend_inverse:
                    param.addcdiv_(exp_avg_hat, de_nom, value=-group['lr'])
                    continue

                scale = de_nom.reciprocal().mul_(1.0 - inverse_weight).add_(de_nom, alpha=inverse_weight)
                param.addcmul_(exp_avg_hat, scale, value=-group['lr'])

        return loss

DynamicLossScaler

Dynamically adjusts the loss scaling factor.

Dynamic loss scalers are important in mixed-precision training. They help us avoid underflows and overflows in low-precision gradients.

See here for information: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling

Shamelessly stolen and adapted from FairSeq: https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py

Reference: 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'

Parameters:

Name Type Description Default
init_scale float

Initial loss scale.

2.0 ** 15
scale_factor float

Factor by which to increase or decrease loss scale.

2.0
scale_window int

If no overflow occurs within scale_window iterations, the loss scale will increase by scale_factor.

2000
tolerance float

Fraction (between 0 and 1) of the iterations since the last rescale that may overflow before the loss scale is decreased.

0.0
threshold float

Minimum threshold below which the loss scale will not decrease.

None
Source code in pytorch_optimizer/optimizer/fp16.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class DynamicLossScaler:
    """Dynamically adjusts the loss scaling factor.

    Dynamic loss scalers are important in mixed-precision training.
    They help us avoid underflows and overflows in low-precision gradients.

    See here for information:
    <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling>

    Shamelessly stolen and adapted from FairSeq:
    <https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py>

    Reference:
    'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'

    Args:
        init_scale (float): Initial loss scale.
        scale_factor (float): Factor by which the loss scale is multiplied (growth) or divided (decrease).
        scale_window (int): The loss scale is multiplied by scale_factor whenever the number of iterations since the
            last overflow is a multiple of scale_window.
        tolerance (float): Fraction (0 to 1) of iterations since the last rescale that may overflow before the loss
            scale is decreased.
        threshold (float, optional): Minimum value below which the loss scale will not decrease.

    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 15,
        scale_factor: float = 2.0,
        scale_window: int = 2000,
        tolerance: float = 0.00,
        threshold: Optional[float] = None,
    ):  # fmt: skip
        self.loss_scale = init_scale
        self.scale_factor, self.scale_window = scale_factor, scale_window
        self.tolerance, self.threshold = tolerance, threshold

        # Iteration bookkeeping; -1 means "never happened yet".
        self.iter: int = 0
        self.last_overflow_iter: int = -1
        self.last_rescale_iter: int = -1
        self.overflows_since_rescale: int = 0
        self.has_overflow_serial: bool = False

    def update_scale(self, overflow: bool):
        r"""Update the loss scale after one iteration.

        On overflow, the scale is decreased once the fraction of overflowing iterations since the last rescale
        reaches ``tolerance``. Without overflow, the scale grows by ``scale_factor`` each time the distance to the
        last overflow is a multiple of ``scale_window``.

        Args:
            overflow (bool): Whether the current iteration overflowed.
        """
        current: int = self.iter

        if not overflow:
            if (current - self.last_overflow_iter) % self.scale_window == 0:
                self.loss_scale *= self.scale_factor
                self.last_rescale_iter = current
        else:
            self.last_overflow_iter = current
            self.overflows_since_rescale += 1

            overflow_ratio: float = self.overflows_since_rescale / float(current - self.last_rescale_iter)
            if overflow_ratio >= self.tolerance:
                self.decrease_loss_scale()
                self.last_rescale_iter = current
                self.overflows_since_rescale = 0

        self.iter = current + 1

    def decrease_loss_scale(self):
        r"""Divide the loss scale by ``scale_factor``, never going below ``threshold`` when one is set."""
        reduced = self.loss_scale / self.scale_factor
        self.loss_scale = reduced if self.threshold is None else max(reduced, self.threshold)

decrease_loss_scale()

Decrease the loss scale by self.scale_factor.

NOTE: the loss_scale will not go below self.threshold.

Source code in pytorch_optimizer/optimizer/fp16.py
86
87
88
89
90
91
92
93
def decrease_loss_scale(self):
    r"""Divide the loss scale by ``self.scale_factor``.

    NOTE: when ``self.threshold`` is set, the result is clamped so the loss_scale never goes below it.
    """
    reduced = self.loss_scale / self.scale_factor
    self.loss_scale = reduced if self.threshold is None else max(reduced, self.threshold)

update_scale(overflow)

Update the loss scale.

If the fraction of overflowing iterations since the last rescale reaches the tolerance, we decrease the loss scale.
If the number of iterations since the last overflow reaches a multiple of the scale window, we increase the loss scale.

:param overflow: bool. whether the current iteration overflowed.

Source code in pytorch_optimizer/optimizer/fp16.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def update_scale(self, overflow: bool):
    r"""Update the loss scale after one iteration.

    On overflow, the scale is decreased once the fraction of overflowing iterations since the last rescale reaches
    ``self.tolerance``. Without overflow, the scale grows by ``self.scale_factor`` each time the distance to the last
    overflow is a multiple of ``self.scale_window``.

    :param overflow: bool. whether the current iteration overflowed.
    """
    current: int = self.iter

    if not overflow:
        if (current - self.last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self.last_rescale_iter = current
    else:
        self.last_overflow_iter = current
        self.overflows_since_rescale += 1

        overflow_ratio: float = self.overflows_since_rescale / float(current - self.last_rescale_iter)
        if overflow_ratio >= self.tolerance:
            self.decrease_loss_scale()
            self.last_rescale_iter = current
            self.overflows_since_rescale = 0

    self.iter = current + 1

EmoFact

Bases: BaseOptimizer

EmoFact optimizer.

EmoFact is inspired by AdaFactor and adopts its VRAM-friendly factored design, keeping only row and column statistics for matrix parameters.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and squared gradient.

(0.9, 0.999)
use_shadow bool

Whether to use shadow weights or not.

False
shadow_weight float

The weight of the shadow.

0.05
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/emonavi.py
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
class EmoFact(BaseOptimizer):
    """EmoFact optimizer.

    EmoFact is inspired by AdaFactor and adopts its VRAM-friendly factored design: parameters with two or more
    dimensions keep only a per-row and a per-column running statistic (``exp_avg_r`` / ``exp_avg_c``), while lower
    dimensional parameters keep full first and second moments. The final step applies the sign of the normalized
    update, scaled by ``lr`` and the emotion drive returned by ``get_emo_drive``.

    NOTE(review): ``shadow_weight`` is validated and stored in the defaults, but this class never reads it; the shadow
    blend uses a fixed ``0.1 * abs(trust)``. Confirm whether that is intended.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of gradient and squared gradient.
        use_shadow (bool): Whether to use shadow weights or not.
        shadow_weight (float): The weight of the shadow.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        use_shadow: bool = False,
        shadow_weight: float = 0.05,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(shadow_weight, 'shadow_weight', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.lr = lr
        self.maximize = maximize

        defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            use_shadow=use_shadow,
            shadow_weight=shadow_weight,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'EmoFact'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            state = self.state[param]
            if state:
                continue

            if group['use_shadow']:
                state['shadow'] = param.clone()

            shape = param.size()
            if len(shape) < 2:
                state['exp_avg'] = torch.zeros_like(param)
                state['exp_avg_sq'] = torch.zeros_like(param)
            else:
                # Factored statistics: one value per row (dim 0) and one per position of the remaining dims.
                row_shape = [shape[0]] + [1] * (len(shape) - 1)
                col_shape = [1, *shape[1:]]
                state['exp_avg_r'] = torch.zeros(row_shape, dtype=param.dtype, device=param.device)
                state['exp_avg_c'] = torch.zeros(col_shape, dtype=param.dtype, device=param.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        # Without a closure the loss stays 0.0 (not None); this is the value get_emo_drive sees and step returns.
        loss = 0.0
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            emo_drive, ratio, trust = get_emo_drive(self.state, loss, group['use_shadow'])

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['use_shadow']:
                    shadow = state['shadow']
                    if ratio > 0.0:
                        param.mul_(1.0 - ratio).add_(shadow, alpha=abs(trust))
                    else:
                        shadow.lerp_(param, weight=0.1 * abs(trust))

                eps: float = group['eps']

                if grad.dim() < 2:
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                    update = exp_avg / exp_avg_sq.sqrt().add_(eps)
                else:
                    grad_sq = grad.pow(2)
                    row_rms = torch.mean(grad_sq, dim=tuple(range(1, grad.dim())), keepdim=True).add_(eps).sqrt_()
                    col_rms = torch.mean(grad_sq, dim=0, keepdim=True).add_(eps).sqrt_()

                    # Both factored statistics are smoothed with beta1, as in the original implementation.
                    exp_avg_r, exp_avg_c = state['exp_avg_r'], state['exp_avg_c']
                    exp_avg_r.mul_(beta1).add_(row_rms, alpha=1.0 - beta1)
                    exp_avg_c.mul_(beta1).add_(col_rms, alpha=1.0 - beta1)

                    update = grad / (exp_avg_r * exp_avg_c).sqrt_().add_(eps)

                update.sign_()

                param.add_(update, alpha=-group['lr'] * emo_drive)

        self.prev_loss = loss

        return loss

EmoLynx

Bases: BaseOptimizer

EmoLynx optimizer.

Lynx was developed with inspiration from Lion and Tiger, which we deeply respect for their lightweight and intelligent design. It also integrates EmoNAVI to enhance its capabilities.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize, or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.99)
use_shadow bool

Whether to use shadow feature.

False
shadow_weight float

The weight of the shadow.

0.05
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/emonavi.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
class EmoLynx(BaseOptimizer):
    """EmoLynx optimizer.

    Lynx was developed with inspiration from Lion and Tiger, which we deeply respect for their lightweight and
    intelligent design. It also integrates EmoNAVI to enhance its capabilities.

    The update is Lion-style: the sign of an interpolation between the current gradient and a momentum buffer,
    scaled by the learning rate and by the ``emo_drive`` factor returned from ``get_emo_drive``.

    Args:
        params (ParamsT): Iterable of parameters to optimize, or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): ``beta1`` interpolates the gradient with the momentum buffer before taking the sign, and
            ``beta2`` is the decay rate of the momentum buffer.
        use_shadow (bool): Whether to use shadow feature.
        shadow_weight (float): Interpolation weight used to pull the shadow towards the parameters right after the
            shadow has been blended into them.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability. The sign-based update has no
            denominator, so this value is stored in the defaults but not used by ``step``.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.99),
        use_shadow: bool = False,
        shadow_weight: float = 0.05,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(shadow_weight, 'shadow_weight', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'use_shadow': use_shadow,
            'shadow_weight': shadow_weight,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'EmoLynx'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and lazily create per-parameter state.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if group['use_shadow']:
                    # snapshot of the parameters that is later blended back into them
                    state['shadow'] = p.clone()
                state['exp_avg'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): A closure that re-evaluates the model and returns the loss. When omitted, ``0.0`` is
                used as the loss passed to ``get_emo_drive`` and is also the returned value.
        """
        loss = 0.0
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            # step scale and shadow-blending factors derived from the loss
            emo_drive, ratio, trust = get_emo_drive(self.state, loss, group['use_shadow'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['use_shadow']:
                    shadow = state['shadow']
                    if ratio > 0.0:
                        # blend the shadow into the parameters, then pull the shadow towards the blended parameters
                        # by ``shadow_weight`` so it keeps tracking them (same as EmoNavi)
                        p.mul_(1.0 - ratio).add_(shadow, alpha=abs(trust))
                        shadow.lerp_(p, weight=group['shadow_weight'])
                    else:
                        # no blending this step: let the shadow follow the parameters at a trust-scaled rate
                        leap_ratio: float = 0.1 * abs(trust)
                        shadow.lerp_(p, weight=leap_ratio)

                exp_avg = state['exp_avg']

                # Lion-style direction: sign of the beta1-interpolation of momentum and gradient
                blended_grad = grad.mul(1.0 - beta1).add_(exp_avg, alpha=beta1).sign_()
                exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)

                p.add_(blended_grad, alpha=-group['lr'] * emo_drive)

        return loss

EmoNavi

Bases: BaseOptimizer

An emotion-driven optimizer that feels loss and navigates accordingly.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
use_shadow bool

Whether to use shadowing or not.

False
shadow_weight float

The weight of the shadow.

0.05
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/emonavi.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
class EmoNavi(BaseOptimizer):
    """An emotion-driven optimizer that feels loss and navigates accordingly.

    Every step applies an Adam-style update without bias correction,
    ``p <- p - lr * emo_drive * exp_avg / (sqrt(exp_avg_sq) + eps)``. Here ``emo_drive``, together with the
    shadow-blending factors, comes from ``get_emo_drive`` evaluated on the current loss.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and the squared gradient.
        use_shadow (bool): Whether to keep a shadow copy of the parameters that can be blended back into them.
        shadow_weight (float): Interpolation weight used to pull the shadow towards the parameters after a blend.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        use_shadow: bool = False,
        shadow_weight: float = 0.05,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        # validation order matters: the first failing check determines the raised error
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(shadow_weight, 'shadow_weight', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.use_shadow = use_shadow
        self.maximize = maximize

        defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            use_shadow=use_shadow,
            shadow_weight=shadow_weight,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'EmoNavi'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Set up the group's step counter and allocate moments (and shadow) for new parameters.

        Raises:
            NoSparseGradientError: for sparse gradients.
            NoComplexParameterError: for complex parameters.
        """
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))
            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            param_state = self.state[param]
            if param_state:
                continue

            param_state['exp_avg'] = torch.zeros_like(param)
            param_state['exp_avg_sq'] = torch.zeros_like(param)
            if group['use_shadow']:
                param_state['shadow'] = param.clone()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): Re-evaluates the model and returns the loss. Without it, ``0.0`` is handed to
                ``get_emo_drive`` and returned.
        """
        loss = 0.0
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            lr, eps = group['lr'], group['eps']

            emo_drive, ratio, trust = get_emo_drive(self.state, loss, group['use_shadow'])
            step_size = -lr * emo_drive

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['use_shadow']:
                    shadow = param_state['shadow']
                    if ratio > 0.0:
                        # mix the shadow into the weights, then let the shadow catch up by ``shadow_weight``
                        param.mul_(1.0 - ratio).add_(shadow, alpha=abs(trust))
                        shadow.lerp_(param, weight=group['shadow_weight'])
                    else:
                        # otherwise the shadow drifts towards the weights at a trust-scaled rate
                        shadow.lerp_(param, weight=0.1 * abs(trust))

                exp_avg = param_state['exp_avg']
                exp_avg_sq = param_state['exp_avg_sq']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                param.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=step_size)

        return loss

EXAdam

Bases: BaseOptimizer

The Power of Adaptive Cross-Moments.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/exadam.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class EXAdam(BaseOptimizer):
    """The Power of Adaptive Cross-Moments.

    Adam variant whose bias-corrected first and second moments are each scaled by a cross-moment factor derived
    from the other moment. The current gradient, scaled like the first moment, is added to the numerator.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and the squared gradient.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'EXAdam'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Create the step counter and zero-initialised moments for parameters seen for the first time.

        Raises:
            NoSparseGradientError: for sparse gradients.
        """
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if not param_state:
                param_state['exp_avg'] = torch.zeros_like(param)
                param_state['exp_avg_sq'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): Optional closure that re-evaluates the model and returns the loss.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            lr, eps = group['lr'], group['eps']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            # complements of the bias corrections, used to weight the cross-moment factors
            cross_weight1: float = 1.0 - bias_correction2
            cross_weight2: float = 1.0 - bias_correction1

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                param, grad, exp_avg, exp_avg_sq = self.view_as_real(
                    param, grad, param_state['exp_avg'], param_state['exp_avg_sq']
                )

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                squared_avg = exp_avg.pow(2)

                first_factor = 1.0 + exp_avg_sq.div(exp_avg_sq.add(eps)) * cross_weight1
                second_factor = 1.0 + squared_avg.div(squared_avg.add(eps)) * cross_weight2

                numerator = exp_avg.div(bias_correction1) * first_factor + grad.div(bias_correction1) * first_factor
                denominator = (exp_avg_sq.div(bias_correction2) * second_factor).sqrt().add_(eps)

                param.add_(numerator / denominator, alpha=-lr)

        return loss

FAdam

Bases: BaseOptimizer

Adam is a natural gradient optimizer using diagonal empirical Fisher information.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for the running averages of the momentum and of the diagonal empirical Fisher information (squared gradient).

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.1
clip float

Maximum norm of the gradient.

1.0
p float

Momentum factor.

0.5
eps float

Term added to the denominator to improve numerical stability.

1e-08
momentum_dtype dtype

Dtype of momentum.

float32
fim_dtype dtype

Dtype of Fisher information matrix.

float32
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/fadam.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class FAdam(BaseOptimizer):
    """Adam is a natural gradient optimizer using diagonal empirical Fisher information.

    The gradient is preconditioned by ``fim ** p + eps'``, where ``fim`` is a running average of squared gradients
    and ``eps'`` shrinks with the gradient RMS. The result is RMS-clipped and then fed into momentum. Weight decay
    is applied on the preconditioned, RMS-clipped parameters.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for the running averages of the momentum and of the diagonal empirical
            Fisher information (squared gradient).
        weight_decay (float): Weight decay (L2 penalty).
        clip (float): Maximum norm of the gradient.
        p (float): Momentum factor.
        eps (float): Term added to the denominator to improve numerical stability.
        momentum_dtype (torch.dtype): Dtype of momentum.
        fim_dtype (torch.dtype): Dtype of Fisher information matrix.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.1,
        clip: float = 1.0,
        p: float = 0.5,
        eps: float = 1e-8,
        momentum_dtype: torch.dtype = torch.float32,
        fim_dtype: torch.dtype = torch.float32,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(clip, 'clip')
        self.validate_positive(p, 'p')
        self.validate_non_negative(eps, 'eps')

        self.momentum_dtype = momentum_dtype
        self.fim_dtype = fim_dtype
        self.maximize = maximize

        defaults: Defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay, clip=clip, p=p, eps=eps)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'FAdam'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Create the step counter plus momentum / Fisher buffers in their configured dtypes.

        Raises:
            NoSparseGradientError: for sparse gradients.
            NoComplexParameterError: for complex parameters.
        """
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))
            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            param_state = self.state[param]
            if not param_state:
                param_state['momentum'] = torch.zeros_like(param, dtype=self.momentum_dtype)
                param_state['fim'] = torch.zeros_like(param, dtype=self.fim_dtype)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): Optional closure that re-evaluates the model and returns the loss.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        def clip_by_rms_(tensor: torch.Tensor, clip: float) -> torch.Tensor:
            # in place: divide ``tensor`` by ``max(1, rms(tensor)) / clip`` and return it
            rms = tensor.pow(2).mean().sqrt_()
            return tensor.div_(max(1, rms) / clip)

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            clip, eps, power = group['clip'], group['eps'], group['p']

            curr_beta2: float = self.debias_beta(beta2, group['step'])

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]
                momentum, fim = param_state['momentum'], param_state['fim']

                fim.mul_(curr_beta2).addcmul_(grad, grad, value=1.0 - curr_beta2)

                # the stabilising epsilon is scaled down when the gradient RMS is below 1
                rms_grad = grad.pow(2).mean().sqrt_()
                fim_base = fim.pow(power).add_(min(rms_grad, 1) * eps)

                natural_grad = clip_by_rms_(grad / fim_base, clip)
                momentum.mul_(beta1).add_(natural_grad, alpha=1.0 - beta1)

                decay_term = clip_by_rms_(param / fim_base, clip)
                decay_term.mul_(group['weight_decay']).add_(momentum)

                param.add_(decay_term, alpha=-group['lr'])

        return loss

Fira

Bases: BaseOptimizer

Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint? Fira with AdamW optimizer.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/fira.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class Fira(BaseOptimizer):
    """Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint? Fira with AdamW optimizer.

    Low-rank projection is enabled when the parameter group carries a ``rank`` entry (supplied via ``**kwargs``).
    In that case, 2-D parameters keep their Adam moments in the space produced by a ``GaLoreProjector``. The part
    of the full-rank gradient that the projection discards (``full_grad - project_back(low_rank_grad)``) is
    rescaled by a norm-based factor and added back to the update, with a limiter on how fast its norm may grow.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and the squared gradient.
        weight_decay (float): Weight decay (L2 penalty).
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.
        **kwargs: Extra entries stored in the parameter-group defaults. Supplying ``rank`` enables the projection
            path, which then also reads ``update_proj_gap``, ``scale`` and ``projection_type``.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Fira'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and validate gradients.

        Per-parameter state is created lazily in ``step`` because its shape depends on the projection.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): Optional closure that re-evaluates the model and returns the loss.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # Keep the full-rank gradient. The projection below rebinds ``grad`` to its low-rank counterpart,
                # but the Fira residual must be ``full_grad - project_back(low_rank_grad)``. Using the projected
                # gradient there would mix low-rank and full-rank shapes.
                full_grad = grad

                use_projection: bool = 'rank' in group and p.dim() == 2
                if use_projection:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )

                    grad = state['projector'].project(full_grad, group['step'])

                # moments live in the (possibly projected) gradient space
                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                norm_grad = exp_avg / de_nom

                if use_projection:
                    sub_grad = state['projector'].project_back(grad)

                    norm_dim: int = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1

                    # ratio of Adam-normalised to raw gradient norms, used to rescale the discarded residual
                    scaling_factor = torch.norm(norm_grad, dim=norm_dim) / (torch.norm(grad, dim=norm_dim) + 1e-8)
                    if norm_dim == 1:
                        scaling_factor = scaling_factor.unsqueeze(1)

                    scaling_grad = full_grad.sub(sub_grad).mul_(scaling_factor)

                    if 'scaling_grad' in state:
                        scaling_grad_norm = torch.norm(scaling_grad)

                        # norm-growth limiter: allow at most 1% growth relative to the previous residual norm
                        limiter = max(scaling_grad_norm / (state['scaling_grad'] + 1e-8), 1.01) / 1.01
                        scaling_grad.div_(limiter)

                        state['scaling_grad'] = scaling_grad_norm / limiter
                    else:
                        state['scaling_grad'] = torch.norm(scaling_grad)

                    norm_grad = state['projector'].project_back(norm_grad).add_(scaling_grad)

                p.add_(norm_grad, alpha=-step_size)

                # decoupled (AdamW-style) decay applied after the update
                self.apply_weight_decay(
                    p=p,
                    grad=full_grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

        return loss

FlashAdamW

Bases: BaseOptimizer

FlashOptim-style AdamW with compressed optimizer states.

The optimizer mirrors FlashOptim's AdamW semantics while keeping the implementation portable for environments where Triton kernels are not available. It supports grouped 8-bit optimizer-state compression, compressed state dicts, optional low-precision master-weight error correction, and fully LR-decoupled weight decay.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and squared gradient.

(0.9, 0.999)
eps float

Term added to the denominator to improve numerical stability.

1e-08
weight_decay float

Decoupled weight decay coefficient.

0.01
decouple_lr bool

Scale weight decay by lr / initial_lr instead of lr.

False
quantize bool

Store Adam moments as grouped 8-bit values plus fp16 scales.

True
compress_state_dict bool

Save quantized states in checkpoints when quantize is enabled.

True
master_weight_bits Optional[int]

Effective master-weight precision for bf16/fp16 parameters. Supports None, 24, and 32.

None
check_numerics bool

Raise if low-precision parameter updates are unlikely to alter the master weight.

False
fused bool

Placeholder for FlashOptim's Triton fused path. Currently unsupported in this portable backend.

False
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/flash_adamw.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
class FlashAdamW(BaseOptimizer):
    """FlashOptim-style AdamW with compressed optimizer states.

    The optimizer mirrors FlashOptim's AdamW semantics while keeping the implementation portable for environments where
    Triton kernels are not available. It supports grouped 8-bit optimizer-state compression, compressed state dicts,
    optional low-precision master-weight error correction, and optional LR-decoupled weight decay (``decouple_lr``).

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of gradient and squared gradient.
        eps (float): Term added to the denominator to improve numerical stability.
        weight_decay (float): Decoupled weight decay coefficient.
        decouple_lr (bool): Scale weight decay by ``lr / initial_lr`` instead of ``lr``. When ``initial_lr`` is zero the
            scale falls back to ``1.0``.
        quantize (bool): Store Adam moments as grouped 8-bit values plus fp16 scales.
        compress_state_dict (bool): Save quantized states in checkpoints when ``quantize`` is enabled.
        master_weight_bits (Optional[int]): Effective master-weight precision for bf16/fp16 parameters. Supports
            ``None``, ``24``, and ``32``.
        check_numerics (bool): Raise if low-precision parameter updates are unlikely to alter the master weight.
        fused (bool): Placeholder for FlashOptim's Triton fused path. Currently unsupported in this portable backend.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 1e-2,
        decouple_lr: bool = False,
        quantize: bool = True,
        compress_state_dict: bool = True,
        master_weight_bits: Optional[int] = None,
        check_numerics: bool = False,
        fused: bool = False,
        maximize: bool = False,
        eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(eps, 'eps')
        self.validate_non_negative(weight_decay, 'weight_decay')

        if master_weight_bits not in VALID_MASTER_WEIGHT_BITS:
            raise ValueError(f'master_weight_bits must be one of {VALID_MASTER_WEIGHT_BITS}')

        if fused:
            raise NotImplementedError('FlashAdamW fused Triton kernels are not available in this portable backend')

        self.maximize = maximize
        self.compress_state_dict = compress_state_dict
        self.check_numerics = check_numerics
        self.master_byte_width = BITS_TO_BYTES[master_weight_bits]
        # Cached max |p| per parameter (keyed by id), used by ``maybe_check_numerics``.
        self.param_absmax: Dict[int, float] = {}

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'weight_decay': weight_decay,
            'decouple_lr': decouple_lr,
            'quantize': quantize,
            'master_byte_width': self.master_byte_width,
            **kwargs,
        }

        super().__init__(params, defaults)

        # Remember the starting lr so ``decouple_lr`` can scale weight decay by the schedule multiplier.
        for group in self.param_groups:
            group.setdefault('initial_lr', group['lr'])

        if master_weight_bits is not None and all(
            p.dtype == torch.float32 for group in self.param_groups for p in group['params']
        ):
            raise ValueError('master_weight_bits has no effect when all parameters are fp32')

    def __str__(self) -> str:
        return 'FlashAdamW'

    def maybe_check_numerics(self, p: torch.Tensor, lr: float, master_byte_width: int) -> None:
        """Raise ``ArithmeticError`` if ``lr`` is too small to change a low-precision parameter.

        The parameter's max-abs value is cached on first use (see ``recompute_param_stats`` to refresh it). The check is
        skipped when ``check_numerics`` is off, the parameter is fp32, or ``lr`` is zero.
        """
        if not self.check_numerics or p.dtype == torch.float32 or lr == 0.0:
            return

        max_abs = self.param_absmax.get(id(p))
        if max_abs is None:
            self.param_absmax[id(p)] = max_abs = float(p.detach().abs().max().item()) if p.numel() > 0 else 0.0

        if max_abs <= 0.0 or not math.isfinite(max_abs):
            return

        bits: int = max(DTYPE_WIDTHS[p.dtype], master_byte_width) * 8
        resolution: float = max_abs * 2.0 ** (-(bits - 1))

        if lr * 0.1 < resolution:
            raise ArithmeticError('learning rate is too small to update low-precision FlashAdamW parameters')

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily create moment buffers and master-weight error bits for ``group``."""
        if 'step' not in group:
            group['step'] = 0

        group.setdefault('initial_lr', group.get('lr'))

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]
            if 'exp_avg' not in state and _quantized_key('exp_avg') not in state:
                store_state(state, 'exp_avg', torch.zeros_like(p, dtype=torch.float32), group['quantize'], p.dtype)
                store_state(state, 'exp_avg_sq', torch.zeros_like(p, dtype=torch.float32), group['quantize'], p.dtype)

            # Extra bytes beyond the parameter's own dtype width are kept as error-correction bits.
            error_bytes = group['master_byte_width'] - DTYPE_WIDTHS[p.dtype]
            if error_bytes > 0 and 'error_bits' not in state:
                error_dtype = torch.int8 if error_bytes == 1 else torch.int16
                state['error_bits'] = torch.zeros_like(p, dtype=error_dtype)

    @staticmethod
    def get_param_fp32(p: torch.Tensor, state: Dict[str, Any]) -> torch.Tensor:
        """Return the fp32 master value of ``p``, reconstructed from error bits when they exist."""
        return reconstruct_fp32_param(p, state['error_bits']) if 'error_bits' in state else p.to(torch.float32)

    @staticmethod
    def set_param_fp32(p: torch.Tensor, state: Dict[str, Any], value: torch.Tensor, master_byte_width: int) -> None:
        """Write fp32 ``value`` into ``p`` and refresh its error bits, if any."""
        p.copy_(value.to(p.dtype))
        if 'error_bits' in state:
            state['error_bits'].copy_(compute_ecc_bits(value, p, master_byte_width))

    def recompute_param_stats(self) -> None:
        """Refresh the cached max-abs value of every parameter used by ``maybe_check_numerics``."""
        for group in self.param_groups:
            for p in group['params']:
                self.param_absmax[id(p)] = float(p.detach().abs().max().item()) if p.numel() > 0 else 0.0

    @staticmethod
    def _weight_decay_scale(group: ParamGroup) -> float:
        """Return the factor multiplying ``weight_decay`` for the current step of ``group``.

        ``lr`` by default (PyTorch AdamW semantics); ``lr / initial_lr`` when ``decouple_lr`` is set, falling back to
        ``1.0`` when ``initial_lr`` is zero to avoid dividing by zero.
        """
        if not group['decouple_lr']:
            return group['lr']

        initial_lr = group['initial_lr']
        return group['lr'] / initial_lr if initial_lr else 1.0

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            # Multiplicative shrink for decoupled weight decay; identical for every parameter of the group.
            decay: float = group['weight_decay'] * self._weight_decay_scale(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                grad = p.grad.to(torch.float32)

                self.maximize_gradient(grad, maximize=self.maximize)

                self.maybe_check_numerics(p, group['lr'], group['master_byte_width'])

                exp_avg = materialize_state(state, 'exp_avg')
                exp_avg_sq = materialize_state(state, 'exp_avg_sq')

                param_fp32 = self.get_param_fp32(p, state)

                # Decoupled (AdamW-style) weight decay on the fp32 master copy, before the Adam update.
                if decay > 0.0:
                    param_fp32.mul_(1.0 - decay)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                denominator = exp_avg_sq.div(bias_correction2).sqrt_().add_(group['eps'])
                param_fp32.addcdiv_(exp_avg.div(bias_correction1), denominator, value=-group['lr'])

                self.set_param_fp32(p, state, param_fp32, group['master_byte_width'])
                store_state(state, 'exp_avg', exp_avg, group['quantize'], p.dtype)
                store_state(state, 'exp_avg_sq', exp_avg_sq, group['quantize'], p.dtype)

        return loss

    def state_dict(self) -> Dict[str, Any]:
        """Return the optimizer state; quantized moments are expanded to fp32 unless ``compress_state_dict`` is set."""
        state_dict = super().state_dict()
        if self.compress_state_dict:
            return state_dict

        state_dict['state'] = {param_id: dict(param_state) for param_id, param_state in state_dict['state'].items()}
        for param_state in state_dict['state'].values():
            for name in ('exp_avg', 'exp_avg_sq'):
                q_key, s_key = _quantized_key(name), _scales_key(name)
                if q_key not in param_state:
                    continue
                param_state[name] = dequantize_state(
                    param_state.pop(q_key), param_state.pop(s_key), *_state_spec(name)
                )

        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load state and convert each moment to the storage format required by its group's ``quantize`` flag."""
        super().load_state_dict(state_dict)

        for group in self.param_groups:
            group.setdefault('initial_lr', group['lr'])

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if not state:
                    continue

                for name in ('exp_avg', 'exp_avg_sq'):
                    if group['quantize'] and name in state:
                        store_state(state, name, state.pop(name).to(torch.float32), True, p.dtype)
                    elif group['quantize'] and _quantized_key(name) in state:
                        signed, _, _ = _state_spec(name)
                        quantized_dtype = torch.int8 if signed else torch.uint8
                        state[_quantized_key(name)] = state[_quantized_key(name)].to(quantized_dtype)
                        state[_scales_key(name)] = state[_scales_key(name)].to(torch.float16)
                    elif not group['quantize'] and _quantized_key(name) in state:
                        state[name] = dequantize_state(
                            state.pop(_quantized_key(name)), state.pop(_scales_key(name)), *_state_spec(name)
                        ).to(dtype=p.dtype)

    def get_fp32_model_state_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Return detached fp32 master copies of every named parameter of ``model``."""
        return {
            name: self.get_param_fp32(param.detach(), self.state.get(param, {})).detach().clone()
            for name, param in model.named_parameters()
        }

    @torch.no_grad()
    def set_fp32_model_state_dict(self, model: nn.Module, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load fp32 master weights from ``state_dict`` into ``model``, rebuilding error bits where required.

        Runs under ``no_grad`` rather than ``inference_mode``: tensors created under inference mode (such as freshly
        allocated ``error_bits``) cannot be updated in place later by ``step``. Parameters that are not managed by this
        optimizer are written directly, without master-weight state.
        """
        # One pass over the groups instead of a linear search per parameter.
        byte_widths: Dict[int, int] = {
            id(p): group['master_byte_width'] for group in self.param_groups for p in group['params']
        }

        for name, param in model.named_parameters():
            if name not in state_dict:
                continue

            value = state_dict[name].to(device=param.device, dtype=torch.float32)

            master_byte_width = byte_widths.get(id(param))
            if master_byte_width is None:
                param.copy_(value.to(param.dtype))
                continue

            state = self.state[param]

            error_bytes = master_byte_width - DTYPE_WIDTHS[param.dtype]
            if error_bytes > 0 and 'error_bits' not in state:
                error_dtype = torch.int8 if error_bytes == 1 else torch.int16
                state['error_bits'] = torch.zeros_like(param, dtype=error_dtype)

            self.set_param_fp32(param, state, value, master_byte_width)

FOCUS

Bases: BaseOptimizer

First Order Concentrated Updating Scheme.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
betas Betas

Coefficients used for computing running averages of the gradient (beta1) and of the parameters (beta2).

(0.9, 0.999)
gamma float

Controls the strength of the attraction.

0.1
weight_decay float

Weight decay coefficient, applied against the bias-corrected running average of the parameters.

0.0
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/focus.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class FOCUS(BaseOptimizer):
    """First Order Concentrated Updating Scheme.

    Keeps a momentum of the gradient and an exponential moving average of the parameters themselves. Every step moves
    each parameter by the sign of the momentum plus a ``gamma``-weighted sign step toward the bias-corrected parameter
    average.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients for the running averages of the gradient (beta1) and of the parameters (beta2).
        gamma (float): Controls the strength of the attraction toward the parameter average. Must lie in ``[0, 1)``.
        weight_decay (float): Weight decay coefficient, applied against the bias-corrected running average of the
            parameters.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        betas: Betas = (0.9, 0.999),
        gamma: float = 0.1,
        weight_decay: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(gamma, 'gamma', 0.0, 1.0, '[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        super().__init__(
            params,
            {'lr': lr, 'betas': betas, 'gamma': gamma, 'weight_decay': weight_decay},
        )

    def __str__(self) -> str:
        return 'FOCUS'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Reject sparse gradients and allocate the momentum / parameter-average buffers on first use."""
        group.setdefault('step', 0)

        for param in group['params']:
            if param.grad is None:
                continue
            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if not state:
                state['exp_avg'] = torch.zeros_like(param)
                state['pbar'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            lr, gamma, weight_decay = group['lr'], group['gamma'], group['weight_decay']

            # Only the parameter average needs bias correction; the momentum is consumed through sign().
            bias_correction2: float = self.debias(beta2, group['step'])

            for param in (p for p in group['params'] if p.grad is not None):
                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]
                param, grad, exp_avg, pbar = self.view_as_real(param, grad, state['exp_avg'], state['pbar'])

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                pbar.mul_(beta2).add_(param, alpha=1.0 - beta2)

                pbar_hat = pbar / bias_correction2

                if weight_decay > 0.0:
                    param.add_(pbar_hat, alpha=-lr * weight_decay)

                attraction = torch.sign(param - pbar_hat).mul_(gamma)
                param.add_(attraction.add_(exp_avg.sign()), alpha=-lr)

        return loss

FriendlySAM

Bases: BaseOptimizer

Friendly Sharpness-Aware Minimization.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer Optimizer

base optimizer.

required
rho float

size of the neighborhood for computing the max loss.

0.05
sigma float

Scale of the gradient momentum subtracted from the current gradient before computing the perturbation.

1.0
lmbda float

Decay factor of the gradient momentum (exponential moving average of gradients).

0.9
adaptive bool

element-wise Adaptive SAM.

False
perturb_eps float

eps for perturbation.

1e-12
kwargs Dict

parameters for optimizer.

{}
Example
model = YourModel()
base_optimizer = Ranger21
optimizer = FriendlySAM(model.parameters(), base_optimizer)

for input, output in data:
    # first forward-backward pass

    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    # make sure to do a full forward pass
    loss_function(output, model(input)).backward()
    optimizer.second_step(zero_grad=True)

Alternative example with a single closure-based step function:

model = YourModel()
base_optimizer = Ranger21
optimizer = FriendlySAM(model.parameters(), base_optimizer)

def closure():
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

for input, output in data:
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.step(closure)
    optimizer.zero_grad()
Source code in pytorch_optimizer/optimizer/sam.py
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
class FriendlySAM(BaseOptimizer):
    """Friendly Sharpness-Aware Minimization.

    Before computing the SAM perturbation, FriendlySAM subtracts ``sigma`` times an exponential moving average of past
    raw gradients (decay ``lmbda``) from the current gradient. The perturbation is therefore computed from the gradient
    with its running mean partially removed.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        base_optimizer (Optimizer): base optimizer.
        rho (float): size of the neighborhood for computing the max loss.
        sigma (float): scale of the gradient momentum subtracted from the gradient before perturbing.
        lmbda (float): decay factor of the gradient momentum (exponential moving average of raw gradients).
        adaptive (bool): element-wise Adaptive SAM.
        perturb_eps (float): eps for perturbation.
        kwargs (Dict): parameters for optimizer.

    Example:
        ```python
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = FriendlySAM(model.parameters(), base_optimizer)

        for input, output in data:
            # first forward-backward pass

            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            # make sure to do a full forward pass
            loss_function(output, model(input)).backward()
            optimizer.second_step(zero_grad=True)

        Alternative example with a single closure-based step function:

        model = YourModel()
        base_optimizer = Ranger21
        optimizer = FriendlySAM(model.parameters(), base_optimizer)

        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss

        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step(closure)
            optimizer.zero_grad()
        ```

    """

    def __init__(
        self,
        params: ParamsT,
        base_optimizer: OptimizerType,
        rho: float = 0.05,
        sigma: float = 1.0,
        lmbda: float = 0.9,
        adaptive: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')
        self.validate_non_negative(sigma, 'sigma')
        self.validate_non_negative(lmbda, 'lmbda')
        self.validate_non_negative(perturb_eps, 'perturb_eps')

        self.perturb_eps = perturb_eps

        defaults: Defaults = {'rho': rho, 'sigma': sigma, 'lmbda': lmbda, 'adaptive': adaptive}
        defaults.update(kwargs)

        super().__init__(params, defaults)

        # The base optimizer owns the parameter groups; share them so both see the same hyper-parameters.
        self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'FriendlySAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        pass

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False) -> None:
        """Correct gradients by the momentum, then move parameters to the perturbed point ``p + e_w``.

        The original parameters are saved in ``state['old_p']`` so ``second_step`` can restore them.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                if 'momentum' not in state:
                    state['momentum'] = grad.clone()
                    continue

                momentum = state['momentum']

                # The momentum EMA tracks raw gradients, so keep a copy before the in-place sigma correction.
                raw_grad = grad.clone()

                grad.sub_(momentum, alpha=group['sigma'])
                momentum.lerp_(raw_grad, weight=1.0 - group['lmbda'])

        device = self.param_groups[0]['params'][0].device

        grad_norm = get_global_gradient_norm(self.param_groups, device).add_(self.perturb_eps)

        for group in self.param_groups:
            scale = group['rho'] / grad_norm

            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['old_p'] = p.clone()

                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p)

                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """Restore the saved parameters and apply the base optimizer using gradients from the perturbed point."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                p.data = self.state[p]['old_p']

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: Closure = None):
        """Run both SAM phases. ``closure`` must recompute the loss and call ``backward``."""
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step(zero_grad=True)

        with torch.enable_grad():
            closure()

        self.second_step()

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

Fromage

Bases: BaseOptimizer

On the distance between two neural networks and the stability of learning.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
p_bound Optional[float]

Restricts the optimization to a bounded set. For example, a value of 2.0 restricts parameter norms to lie within 2x their initial norms, which helps regularize the model class.

None
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/fromage.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class Fromage(BaseOptimizer):
    """On the distance between two neural networks and the stability of learning.

    Each step rescales the gradient to the norm of its parameter (when both norms are non-zero). It then takes an
    ``lr``-sized step and divides the parameter by ``sqrt(1 + lr^2)``. Optionally, it clips the parameter norm to
    ``p_bound`` times its norm at first update.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        p_bound (Optional[float]): Restricts the optimization to a bounded set. For example, a value of 2.0 restricts
            parameter norms to lie within 2x their initial norms, which helps regularize the model class.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self, params: ParamsT, lr: float = 1e-2, p_bound: Optional[float] = None, maximize: bool = False, **kwargs
    ):
        self.validate_learning_rate(lr)

        self.p_bound = p_bound
        self.maximize = maximize

        super().__init__(params, {'lr': lr})

    def __str__(self) -> str:
        return 'Fromage'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Reject sparse gradients and, when ``p_bound`` is set, record each parameter's norm bound on first use."""
        group.setdefault('step', 0)

        for param in group['params']:
            if param.grad is None:
                continue
            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if self.p_bound is not None and not state:
                state['max'] = param.norm().mul_(self.p_bound)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            lr: float = group['lr']
            # Shrinking by sqrt(1 + lr^2) compensates the norm growth of a step orthogonal to the parameter.
            shrink: float = math.sqrt(1 + lr ** 2)

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]
                param, grad = self.view_as_real(param, grad)

                param_norm, grad_norm = param.norm(), grad.norm()
                if param_norm > 0.0 and grad_norm > 0.0:
                    direction = grad * (param_norm / grad_norm)
                else:
                    direction = grad
                param.add_(direction, alpha=-lr)
                param.div_(shrink)

                if self.p_bound is None:
                    continue

                bound = state['max']
                new_norm = param.norm()
                if new_norm > bound:
                    param.mul_(bound).div_(new_norm)

        return loss

FTRL

Bases: BaseOptimizer

Follow The Regularized Leader.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
lr_power float

Controls how the learning rate decreases during training. Use zero for a fixed learning rate.

-0.5
beta float

Beta value as described in the paper.

0.0
lambda_1 float

L1 regularization parameter.

0.0
lambda_2 float

L2 regularization parameter.

0.0
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ftrl.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class FTRL(BaseOptimizer):
    """Follow The Regularized Leader.

    Per-coordinate FTRL-Proximal. The optimizer accumulates a linear term ``z`` and a sum of squared gradients ``n``. It
    then sets each parameter in closed form, ``p = (sign(z) * lambda_1 - z) / ((beta + n^(-lr_power)) / lr +
    lambda_2)``, and zeroes coordinates with ``|z| <= lambda_1``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        lr_power (float): Controls how the learning rate decreases during training. Use zero for a fixed learning rate.
        beta (float): Beta value as described in the paper.
        lambda_1 (float): L1 regularization parameter.
        lambda_2 (float): L2 regularization parameter.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        lr_power: float = -0.5,
        beta: float = 0.0,
        lambda_1: float = 0.0,
        lambda_2: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(beta, 'beta')
        self.validate_non_positive(lr_power, 'lr_power')
        self.validate_non_negative(lambda_1, 'lambda_1')
        self.validate_non_negative(lambda_2, 'lambda_2')

        self.maximize = maximize

        defaults: Defaults = {'lr': lr, 'lr_power': lr_power, 'beta': beta, 'lambda_1': lambda_1, 'lambda_2': lambda_2}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'FTRL'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Reject sparse gradients and allocate the ``z`` / ``n`` accumulators on first use."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = torch.zeros_like(p)
                state['n'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            # lr_power <= 0 (validated), so the exponent below is non-negative and n == 0 never divides by zero.
            exponent: float = -group['lr_power']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z, n = state['z'], state['n']

                p, grad, z, n = self.view_as_real(p, grad, z, n)

                grad_p2 = grad.pow(2)

                sigma = (n + grad_p2).pow_(exponent).sub_(n.pow(exponent)).div_(group['lr'])

                z.add_(grad).sub_(sigma.mul(p))
                n.add_(grad_p2)

                update = z.sign().mul_(group['lambda_1']).sub_(z)
                # Use n^(-lr_power) rather than a hard-coded sqrt so lr_power=0 really gives a fixed learning rate.
                update.div_((group['beta'] + n.pow(exponent)).div_(group['lr']).add_(group['lambda_2']))

                p.copy_(update)
                # `<=` also zeroes coordinates with z == 0 when lambda_1 == 0, which would otherwise be 0 / 0 = NaN
                # whenever n, beta and lambda_2 are all zero.
                p.masked_fill_(z.abs() <= group['lambda_1'], 0.0)

        return loss

GaLore

Bases: BaseOptimizer

AdamW optimizer with GaLore projector.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/galore.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class GaLore(BaseOptimizer):
    """AdamW optimizer with GaLore projector.

    For every parameter group that defines ``rank`` (supplied through ``**kwargs``), the gradient of each 2-D
    parameter is projected by a per-parameter ``GaLoreProjector`` before the Adam moments are updated. The
    normalized step is projected back to the parameter's shape before it is applied. All other parameters receive a
    plain Adam step. Decoupled weight decay is applied after the step.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        weight_decay (float): Weight decay (L2 penalty).
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.
        kwargs: Extra per-group defaults. Groups that contain ``rank`` must also provide ``update_proj_gap``,
            ``scale`` and ``projection_type``, which are forwarded to the projector.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay, eps=eps, **kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'GaLore'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        # Only validates inputs; moment buffers are created lazily in `step` because their
        # shape follows the (possibly projected) gradient rather than the parameter.
        group.setdefault('step', 0)

        for p in group['params']:
            grad = p.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))
            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            debiased_first: float = self.debias(beta1, group['step'])
            debiased_second_root: float = math.sqrt(self.debias(beta2, group['step']))

            # Both bias corrections folded into a single scalar step size.
            step_size: float = group['lr'] * debiased_second_root / debiased_first

            has_rank: bool = 'rank' in group

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]
                use_projection: bool = has_rank and p.dim() == 2

                if use_projection:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )
                    grad = state['projector'].project(grad, group['step'])

                if 'exp_avg' not in state:
                    # Moments live in the (possibly projected) gradient space, so shape them after `grad`.
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                direction = exp_avg / exp_avg_sq.sqrt().add_(group['eps'])
                if use_projection:
                    direction = state['projector'].project_back(direction)

                p.add_(direction, alpha=-step_size)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

        return loss

get_supported_optimizers(filters=None)

Return list of available optimizer names, sorted alphabetically.

Parameters:

Name Type Description Default
filters Optional[Union[str, List[str]]]

Wildcard filter pattern (or list of patterns) matched with fnmatch. If None, the whole list is returned.

None
Source code in pytorch_optimizer/optimizer/__init__.py
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
def get_supported_optimizers(filters: Optional[Union[str, List[str]]] = None) -> List[str]:
    r"""Return list of available optimizer names, sorted alphabetically.

    Args:
        filters (Optional[Union[str, List[str]]]): wildcard pattern, or a list/tuple of patterns, matched with
            ``fnmatch``. A name is kept if it matches any pattern. If None, every registered name is returned.

    """
    names = OPTIMIZERS.keys()
    if filters is None:
        return sorted(names)

    # A single pattern is treated as a one-element list of patterns.
    patterns: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters]

    matched: Set[str] = {name for pattern in patterns for name in fnmatch.filter(names, pattern)}

    return sorted(matched)

Grams

Bases: BaseOptimizer

Gradient Descent with Adaptive Momentum Scaling.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay.

True
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/grams.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class Grams(BaseOptimizer):
    """Gradient Descent with Adaptive Momentum Scaling.

    Grams computes the bias-corrected Adam direction ``m_hat / (sqrt(v_hat) + eps)``. It keeps only the magnitude
    of that direction and takes its sign from the current gradient:
    ``update = sign(grad) * |m_hat / (sqrt(v_hat) + eps)|``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Grams'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and zeroed first/second moment buffers for params with gradients."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            # Square root of the second-moment bias correction (same convention as GaLore/GrokFastAdamW).
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                # Weight decay goes before the moments, matching GrokFastAdamW. NOTE(review): this assumes
                # `apply_weight_decay` only rescales `p` in decoupled mode, which makes its position irrelevant to
                # the update. In coupled mode it is assumed to fold `p` into `grad`, so it must run before `grad`
                # feeds the moments and the sign below; otherwise the L2 term is silently dropped.
                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                # lerp_(grad, w) computes (1 - w) * exp_avg + w * grad, so the EMA weight on the new
                # gradient must be (1 - beta1), not beta1.
                exp_avg.lerp_(grad, weight=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # sqrt(v_hat) = sqrt(v) / sqrt(bias correction). `bias_correction2_sq` is already that root,
                # so divide after the sqrt. Dividing before it would apply only a fourth root.
                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                update = exp_avg.div(bias_correction1).div_(de_nom)
                # Keep the Adam magnitude, but take the direction from the current gradient.
                update.abs_().mul_(grad.sign())

                p.add_(update, alpha=-group['lr'])

        return loss

Gravity

Bases: BaseOptimizer

A Kinematic Approach on Optimization in Deep Learning.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
alpha float

Alpha controls the V initialization.

0.01
beta float

Beta will be used to compute running average of V.

0.9
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/gravity.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class Gravity(BaseOptimizer):
    """A Kinematic Approach on Optimization in Deep Learning.

    Each parameter carries a buffer ``v`` that starts as Gaussian noise with standard deviation ``alpha / lr``.
    Every step, the gradient is squashed elementwise into ``zeta = g / (1 + (g * max|g|) ** 2)``. ``v`` then moves
    toward ``zeta`` with the step-dependent factor ``beta_t = (beta * t + 1) / (t + 2)``, and the parameter moves
    by ``-lr * v``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        alpha (float): Alpha controls the V initialization.
        beta (float): Beta will be used to compute running average of V.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        alpha: float = 0.01,
        beta: float = 0.9,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        self.maximize = maximize

        defaults: Defaults = dict(lr=lr, alpha=alpha, beta=beta)
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Gravity'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue
            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if not param_state:
                noise_std = group['alpha'] / group['lr']
                param_state['v'] = torch.empty_like(param).normal_(mean=0.0, std=noise_std)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            t = group['step']
            beta_t: float = (group['beta'] * t + 1) / (t + 2)
            lr = group['lr']

            for param in group['params']:
                if param.grad is None:
                    continue

                gradient = param.grad
                self.maximize_gradient(gradient, maximize=self.maximize)

                velocity = self.state[param]['v']
                param, gradient, velocity = self.view_as_real(param, gradient, velocity)

                # Reciprocal of the largest absolute gradient entry; used to squash large components.
                inv_peak = 1.0 / gradient.abs().max()
                zeta = gradient / (1.0 + (gradient / inv_peak) ** 2)

                velocity.mul_(beta_t).add_(zeta, alpha=1.0 - beta_t)
                param.add_(velocity, alpha=-lr)

        return loss

GrokFastAdamW

Bases: BaseOptimizer

Accelerated Grokking by Amplifying Slow Gradients with AdamW.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.0001
betas Betas

Coefficients used for computing running averages of the gradient and its square.

(0.9, 0.99)
grokfast bool

Whether to use grokfast.

True
grokfast_alpha float

Momentum hyperparameter of the EMA.

0.98
grokfast_lamb float

Amplifying factor hyperparameter of the filter.

2.0
grokfast_after_step int

Warmup step for grokfast.

0
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/grokfast.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
class GrokFastAdamW(BaseOptimizer):
    """Accelerated Grokking by Amplifying Slow Gradients with AdamW.

    This is AdamW whose gradient can first pass through the Grokfast filter. An EMA of past gradients
    (``grok_exp_avg``) is maintained, and ``grokfast_lamb`` times that EMA is added to the gradient before the Adam
    moments are updated. The filter is applied only when ``grokfast`` is True, ``grokfast_lamb > 0``, and the
    current step is strictly greater than ``grokfast_after_step``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the (filtered) gradient and its square.
        grokfast (bool): Whether to use grokfast.
        grokfast_alpha (float): Momentum hyperparameter of the EMA.
        grokfast_lamb (float): Amplifying factor hyperparameter of the filter.
        grokfast_after_step (int): Warmup step for grokfast.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        normalize_lr (bool): When True and ``grokfast`` is enabled, ``lr`` is divided by ``1 + grokfast_lamb`` at
            construction. This is presumably meant to offset the amplified gradient magnitude.
        eps (float): Term added to the denominator to improve numerical stability.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-4,
        betas: Betas = (0.9, 0.99),
        grokfast: bool = True,
        grokfast_alpha: float = 0.98,
        grokfast_lamb: float = 2.0,
        grokfast_after_step: int = 0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        normalize_lr: bool = True,
        eps: float = 1e-8,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(grokfast_alpha, 'grokfast_alpha', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.foreach = foreach
        self.maximize = maximize

        # The stored group lr is already the normalized value; `normalize_lr` itself is not kept in defaults.
        if grokfast and normalize_lr:
            lr /= 1.0 + grokfast_lamb

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'grokfast': grokfast,
            'grokfast_alpha': grokfast_alpha,
            'grokfast_lamb': grokfast_lamb,
            'grokfast_after_step': grokfast_after_step,
            'foreach': foreach,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'GrokFastAdamW'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and lazily create per-parameter state.

        ``exp_avg``/``exp_avg_sq`` start at zero. The Grokfast EMA is created only when the filter can ever be
        active, and it is seeded with a copy of the first gradient seen rather than with zeros.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['grokfast'] and group['grokfast_lamb'] > 0.0:
                    state['grok_exp_avg'] = grad.clone()

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        """Return False if the group explicitly disables foreach; otherwise defer to ``can_use_foreach``."""
        if group.get('foreach') is False:
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        exp_avg_sqs: List[torch.Tensor],
        grok_exp_avgs: List[torch.Tensor],
        should_grokfast: bool,
    ) -> None:
        """Multi-tensor (``torch._foreach_*``) version of one AdamW + Grokfast step for a group.

        The ``grads`` tensors are modified in place (negated when maximizing, then amplified by the filter).
        """
        beta1, beta2 = group['betas']

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

        if self.maximize:
            torch._foreach_neg_(grads)

        # Weight decay runs before the filter and moments, mirroring `_step_per_param`.
        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=group['lr'],
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        if should_grokfast:
            # EMA of gradients (slow component), then amplify the gradient by lamb * EMA.
            torch._foreach_lerp_(grok_exp_avgs, grads, weight=1.0 - group['grokfast_alpha'])
            torch._foreach_add_(grads, grok_exp_avgs, alpha=group['grokfast_lamb'])

        # lerp with weight (1 - beta1) is the same EMA as mul_(beta1).add_(g, alpha=1 - beta1) in the per-param path.
        torch._foreach_lerp_(exp_avgs, grads, weight=1.0 - beta1)
        torch._foreach_mul_(exp_avg_sqs, beta2)
        torch._foreach_addcmul_(exp_avg_sqs, grads, grads, value=1.0 - beta2)

        # Denominator is the bias-corrected sqrt(v), floored at eps (clamp, not additive eps).
        de_noms = torch._foreach_sqrt(exp_avg_sqs)
        torch._foreach_div_(de_noms, bias_correction2_sq)
        torch._foreach_clamp_min_(de_noms, group['eps'])

        updates = torch._foreach_div(exp_avgs, bias_correction1)
        torch._foreach_div_(updates, de_noms)

        torch._foreach_add_(params, updates, alpha=-group['lr'])

    def _step_per_param(self, group: ParamGroup, should_grokfast: bool) -> None:
        """Per-tensor version of one AdamW + Grokfast step for a group; modifies ``p.grad`` in place."""
        beta1, beta2 = group['betas']

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            # `grok_exp_avg` is absent when the filter can never be active (see `init_group`).
            exp_avg, exp_avg_sq, grok_exp_avg = (
                state['exp_avg'],
                state['exp_avg_sq'],
                state.get('grok_exp_avg', None),
            )

            # NOTE(review): assumes `view_as_real` passes a None argument through unchanged — confirm.
            p, grad, exp_avg, exp_avg_sq, grok_exp_avg = self.view_as_real(p, grad, exp_avg, exp_avg_sq, grok_exp_avg)

            self.apply_weight_decay(
                p=p,
                grad=grad,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=group['fixed_decay'],
            )

            if should_grokfast:
                grok_exp_avg.lerp_(grad, weight=1.0 - group['grokfast_alpha'])
                grad.add_(grok_exp_avg, alpha=group['grokfast_lamb'])

            exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

            de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).clamp_(min=group['eps'])

            update = exp_avg.div(bias_correction1).div_(de_nom)

            p.add_(update, alpha=-group['lr'])

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            # The filter is active only after the warmup step count and when amplification is non-zero.
            should_grokfast: bool = (
                group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0.0
            )

            if self._can_use_foreach(group):
                # NOTE(review): 'grok_exp_avg' is requested even when it was never created (grokfast disabled);
                # presumably `collect_trainable_params` tolerates missing keys — verify.
                params, grads, state_dict = self.collect_trainable_params(
                    group,
                    self.state,
                    state_keys=['exp_avg', 'exp_avg_sq', 'grok_exp_avg'],
                )
                if params:
                    self._step_foreach(
                        group,
                        params,
                        grads,
                        state_dict['exp_avg'],
                        state_dict['exp_avg_sq'],
                        state_dict['grok_exp_avg'],
                        should_grokfast,
                    )
            else:
                self._step_per_param(group, should_grokfast)

        return loss

GSAM

Bases: BaseOptimizer

Surrogate Gap Guided Sharpness-Aware Minimization.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer Optimizer

base optimizer.

required
model Module

model.

required
alpha float

rho alpha.

0.4
rho_scheduler Scheduler

rho scheduler.

required
adaptive bool

element-wise Adaptive SAM.

False
perturb_eps float

epsilon for perturbation.

1e-12
kwargs Dict

parameters for optimizer.

{}
Example
model = YourModel()
base_optimizer = AdamP(model.parameters())
lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps)
rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr)
optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler)

def loss_fn(predictions, targets):
    return F.cross_entropy(predictions, targets)

for inputs, targets in data:
    optimizer.set_closure(loss_fn, inputs, targets)
    predictions, loss = optimizer.step()
    lr_scheduler.step()
    optimizer.update_rho_t()
Source code in pytorch_optimizer/optimizer/sam.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
class GSAM(BaseOptimizer):  # pragma: no cover
    """Surrogate Gap Guided Sharpness-Aware Minimization.

    Each ``step`` runs two forward/backward passes through ``model``:

    1. A pass at the current weights.
    2. A pass at weights perturbed along the (optionally element-wise adaptive) gradient direction with radius
       ``rho_t``.

    The second-pass gradient is then adjusted by subtracting ``alpha`` times the component of the first-pass
    gradient that is orthogonal to it. The perturbation is undone, gradients are synchronized across processes
    (if distributed), and ``base_optimizer.step()`` is called.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        base_optimizer (Optimizer): base optimizer.
        model (nn.Module): model.
        alpha (float): rho alpha.
        rho_scheduler (Scheduler): rho scheduler.
        adaptive (bool): element-wise Adaptive SAM.
        perturb_eps (float): epsilon for perturbation.
        kwargs (Dict): parameters for optimizer.

    Example:
        ```python
        model = YourModel()
        base_optimizer = AdamP(model.parameters())
        lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps)
        rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr)
        optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler)

        def loss_fn(predictions, targets):
            return F.cross_entropy(predictions, targets)

        for inputs, targets in data:
            optimizer.set_closure(loss_fn, inputs, targets)
            predictions, loss = optimizer.step()
            lr_scheduler.step()
            optimizer.update_rho_t()
        ```

    """

    def __init__(
        self,
        params: ParamsT,
        base_optimizer: Optimizer,
        model: nn.Module,
        rho_scheduler,
        alpha: float = 0.4,
        adaptive: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_range(alpha, 'alpha', 0.0, 1.0)

        self.model = model
        self.rho_scheduler = rho_scheduler
        self.alpha = alpha
        self.adaptive = adaptive
        self.perturb_eps = perturb_eps

        self.rho_t: float = 0.0
        self.forward_backward_func: Optional[Callable] = None

        # Prefer a native averaging all-reduce. Without ReduceOp.AVG, fall back to SUM and divide by the
        # world size manually in `sync_grad`.
        if hasattr(ReduceOp, 'AVG'):
            self.grad_reduce = ReduceOp.AVG
            self.manual_average: bool = False
        else:
            self.grad_reduce = ReduceOp.SUM
            self.manual_average: bool = True

        self.base_optimizer = base_optimizer

        defaults: Defaults = {'adaptive': adaptive, **kwargs}

        super().__init__(params, defaults)

        # `torch.optim.Optimizer.__init__` resets `self.param_groups`, so the alias to the base optimizer's groups
        # must be established *after* it. Setting it earlier would silently discard it.
        self.param_groups = self.base_optimizer.param_groups

        self.update_rho_t()

    def __str__(self) -> str:
        return 'GSAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        pass

    @torch.no_grad()
    def update_rho_t(self) -> float:
        """Advance the rho scheduler, store its value in ``self.rho_t`` and return it."""
        self.rho_t = self.rho_scheduler.step()
        return self.rho_t

    @torch.no_grad()
    def perturb_weights(self, rho: float):
        """Move every parameter with a gradient to ``p + e_w``, remembering ``old_g`` and ``e_w`` in state.

        ``e_w = rho * g / (||g|| + perturb_eps)``. In adaptive mode each entry is also weighted by ``p ** 2``, and
        the norm is taken over ``|p| * g``.
        """
        grad_norm = self.grad_norm(weight_adaptive=self.adaptive)
        # The global norm is shared by every group, so the scale is loop-invariant.
        scale = rho / (grad_norm + self.perturb_eps)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['old_g'] = p.grad.clone()

                e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)

                p.add_(e_w)

                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def un_perturb(self):
        """Undo ``perturb_weights`` by subtracting each stored ``e_w``.

        ``e_w`` is removed from state after use. This ensures a perturbation is subtracted exactly once, even if a
        parameter has no gradient on a later step.
        """
        for group in self.param_groups:
            for p in group['params']:
                e_w = self.state[p].pop('e_w', None)
                if e_w is not None:
                    p.sub_(e_w)

    @torch.no_grad()
    def gradient_decompose(self, alpha: float = 0.0):
        """Subtract ``alpha`` times the component of ``old_g`` orthogonal to the current gradient from ``p.grad``."""
        inner_prod = 0.0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                inner_prod += torch.sum(self.state[p]['old_g'] * p.grad)

        new_grad_norm = self.grad_norm(by=None)
        old_grad_norm = self.grad_norm(by='old_g')

        cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # old_g minus its projection onto the current gradient direction.
                vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad / (
                    new_grad_norm + self.perturb_eps
                )
                p.grad.add_(vertical, alpha=-alpha)

    @torch.no_grad()
    def sync_grad(self):
        """Average gradients across processes when ``torch.distributed`` is initialized; otherwise a no-op."""
        if is_initialized():
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    all_reduce(p.grad, op=self.grad_reduce)
                    if self.manual_average:
                        p.grad.div_(float(get_world_size()))

    @torch.no_grad()
    def grad_norm(self, by: Optional[str] = None, weight_adaptive: bool = False) -> torch.Tensor:
        """Return the global L2 norm over all parameters that currently have a gradient.

        Args:
            by (Optional[str]): state key to read the per-parameter tensor from instead of ``p.grad``.
            weight_adaptive (bool): multiply each tensor element-wise by ``|p|`` before taking norms.

        """
        return torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if weight_adaptive else 1.0) * (p.grad if not by else self.state[p][by])).norm(p=2)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )

    def maybe_no_sync(self):
        """Return ``model.no_sync()`` under distributed training if supported, else a no-op context manager."""
        return self.model.no_sync() if is_initialized() and hasattr(self.model, 'no_sync') else ExitStack()

    @torch.no_grad()
    def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs) -> None:
        """Set closure.

        Create `self.forward_backward_func`, a zero-argument function that zeroes the base optimizer's gradients,
        runs the model forward, computes the loss and backpropagates. The inputs and targets are captured when the
        closure is created.

        Args:
            loss_fn (nn.Module): loss function.
            inputs (torch.Tensor): inputs.
            targets (torch.Tensor): targets.
            kwargs (Dict): keyword arguments.

        """

        def get_grad() -> Tuple[Any, torch.Tensor]:
            self.base_optimizer.zero_grad()

            with torch.enable_grad():
                outputs = self.model(inputs)
                loss = loss_fn(outputs, targets, **kwargs)

            loss.backward()

            return outputs, loss.detach()

        self.forward_backward_func = get_grad

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Tuple[Any, torch.Tensor]:
        """Run one GSAM update.

        Returns:
            The outputs and detached loss of the first (unperturbed) forward pass.

        """
        get_grad = cast(Callable[[], Tuple[Any, torch.Tensor]], closure or self.forward_backward_func)

        # Both passes run without DDP gradient sync; gradients are reduced once, explicitly, in `sync_grad`.
        with self.maybe_no_sync():
            outputs, loss = get_grad()

            self.perturb_weights(rho=self.rho_t)

            # Keep the perturbed forward pass from updating normalization running statistics.
            disable_running_stats(self.model)

            get_grad()

            self.gradient_decompose(self.alpha)

            self.un_perturb()

        self.sync_grad()

        self.base_optimizer.step()

        enable_running_stats(self.model)

        return outputs, loss

    def load_state_dict(self, state_dict: Dict):
        """Load state, then re-point the base optimizer at the loaded parameter groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

set_closure(loss_fn, inputs, targets, **kwargs)

Set closure.

Create self.forward_backward_func, a function such that calling self.forward_backward_func() automatically performs the forward and backward passes. It takes no arguments; the inputs and targets are captured when the closure is created.

Parameters:

Name Type Description Default
loss_fn Module

loss function.

required
inputs Tensor

inputs.

required
targets Tensor

targets.

required
kwargs Dict

keyword arguments.

{}
Source code in pytorch_optimizer/optimizer/sam.py
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
@torch.no_grad()
def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs) -> None:
    """Install a zero-argument forward/backward closure as ``self.forward_backward_func``.

    The loss function, inputs, targets and extra keyword arguments are captured now, so
    ``self.forward_backward_func()`` can later be called without arguments. Each call zeroes the
    base optimizer's gradients, runs ``self.model`` on the captured inputs with autograd enabled,
    back-propagates the loss and returns ``(outputs, loss.detach())``.

    Args:
        loss_fn (nn.Module): loss function, called as ``loss_fn(outputs, targets, **kwargs)``.
        inputs (torch.Tensor): inputs fed to ``self.model``.
        targets (torch.Tensor): targets fed to ``loss_fn``.
        kwargs (Dict): extra keyword arguments forwarded to ``loss_fn``.

    """

    def _forward_backward() -> Tuple[Any, torch.Tensor]:
        self.base_optimizer.zero_grad()

        # callers such as a @torch.no_grad() step() invoke this, so autograd is re-enabled explicitly.
        with torch.enable_grad():
            predictions = self.model(inputs)
            objective = loss_fn(predictions, targets, **kwargs)

        objective.backward()

        return predictions, objective.detach()

    self.forward_backward_func = _forward_backward

Kate

Bases: BaseOptimizer

Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
delta float

Delta parameter, typically 0.0 or 1e-8.

0.0
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Whether to fix weight decay.

False
eps float

Epsilon value for numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/kate.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class Kate(BaseOptimizer):
    r"""Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad.

    For every element, with gradient :math:`g_t`, two accumulators :math:`b` and :math:`m` are kept and

    .. math::
        b_t^2 = b_{t-1}^2 + g_t^2 + \epsilon, \qquad
        m_t^2 = m_{t-1}^2 + \delta g_t^2 + g_t^2 / b_t^2, \qquad
        \theta_t = \theta_{t-1} - \eta \, m_t g_t / b_t^2.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        delta (float): Delta parameter, typically 0.0 or 1e-8. Must lie in [0, 1).
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Whether to fix weight decay.
        eps (float): Epsilon value for numerical stability. It is added to the squared accumulator on every step.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        delta: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(delta, 'delta', 0.0, 1.0, '[)')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'delta': delta,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Kate'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the group step counter and the per-parameter `m` / `b` accumulators (zeros)."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['m'] = torch.zeros_like(p)
                state['b'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                m, b = state['m'], state['b']

                p, grad, m, b = self.view_as_real(p, grad, m, b)

                # use the local (possibly real-viewed) gradient: after `view_as_real`, `p` may be a new view whose
                # `.grad` is not the gradient being processed, so coupled decay must target `grad` directly.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                grad_p2 = grad.pow(2)

                # `b` is stored as b_{t-1}; square it, accumulate, so it temporarily holds b_t^2.
                b.mul_(b).add_(grad_p2).add_(group['eps'])
                # m_t = sqrt(m_{t-1}^2 + delta * g^2 + g^2 / b_t^2)
                m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_()

                update = m.mul(grad).div_(b)

                p.add_(update, alpha=-group['lr'])

                # store b_t (not b_t^2) for the next step.
                b.sqrt_()

        return loss

Kron

Bases: BaseOptimizer

PSGD with the Kronecker product pre-conditioner.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
momentum float

momentum factor.

0.9
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
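pre_conditioner_update_probability Optional[Union[Callable[[int], Tensor], float]]

Probability of updating the pre-conditioner, given either as a constant or as a callable that maps the step count to a probability. If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.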

None
max_size_triangular int

Maximum size of a dimension for its pre-conditioner to be triangular.

8192
min_ndim_triangular int

minimum number of dimensions a layer needs to have triangular pre-conditioners.

2
memory_save_mode Optional[str]

None, 'one_diag', or 'all_diag'. None (the default) makes all pre-conditioners triangular, 'one_diag' makes the largest (or last) dimension's pre-conditioner diagonal in each layer, and 'all_diag' makes all pre-conditioners diagonal.

None
momentum_into_precondition_update bool

whether to send momentum into pre-conditioner update instead of raw gradients.

True
mu_dtype Optional[dtype]

dtype of the momentum accumulator.

None
precondition_dtype dtype

dtype of the pre-conditioner.

float32
balance_prob float

probability of performing balancing.

0.01
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/psgd.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
class Kron(BaseOptimizer):
    """PSGD with the Kronecker product pre-conditioner.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): momentum factor.
        weight_decay (float): weight decay. It is added to the pre-conditioned (RMS-capped) update rather than to the
            gradient, and only for parameters with two or more dimensions.
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW. Stored in the param groups but
            currently not read by ``step``.
        pre_conditioner_update_probability (Optional[Union[Callable[[int], torch.Tensor], float]]): probability of
            updating the pre-conditioner, either a constant or a callable mapping the step count to a probability.
            If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps. Only the value stored in the
            first param group is used; a probability <= 0 disables pre-conditioner updates.
        max_size_triangular (int): max size of a dimension for its pre-conditioner to be triangular.
        min_ndim_triangular (int): minimum number of dimensions a layer needs to have triangular pre-conditioners.
        memory_save_mode (Optional[str]): None, 'one_diag', or 'all_diag'. None (the default) makes all
            pre-conditioners triangular, 'one_diag' makes the largest or last dim diagonal per layer, and 'all_diag'
            makes all pre-conditioners diagonal.
        momentum_into_precondition_update (bool): whether to send momentum into pre-conditioner update instead of
            raw gradients.
        mu_dtype (Optional[torch.dtype]): dtype of the momentum accumulator. Defaults to the parameter's dtype.
        precondition_dtype (torch.dtype): dtype of the pre-conditioner. Raw gradients used for the pre-conditioner
            update are cast to this dtype as well.
        balance_prob (float): probability of performing balancing.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        pre_conditioner_update_probability: Optional[Union[Callable[[int], torch.Tensor], float]] = None,
        max_size_triangular: int = 8192,
        min_ndim_triangular: int = 2,
        memory_save_mode: Optional[MEMORY_SAVE_MODE_TYPE] = None,
        momentum_into_precondition_update: bool = True,
        mu_dtype: Optional[torch.dtype] = None,
        precondition_dtype: Optional[torch.dtype] = torch.float32,
        balance_prob: float = 0.01,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        if pre_conditioner_update_probability is None:
            pre_conditioner_update_probability = precondition_update_prob_schedule()

        self.balance_prob: float = balance_prob
        self.eps: float = torch.finfo(torch.bfloat16).tiny
        self.prob_step: int = 0
        self.update_counter: int = 0
        self.maximize = maximize

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'pre_conditioner_update_probability': pre_conditioner_update_probability,
            'max_size_triangular': max_size_triangular,
            'min_ndim_triangular': min_ndim_triangular,
            'memory_save_mode': memory_save_mode,
            'momentum_into_precondition_update': momentum_into_precondition_update,
            'precondition_lr': 1e-1,
            'precondition_init_scale': 1.0,
            'mu_dtype': mu_dtype,
            'precondition_dtype': precondition_dtype,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Kron'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the group step counter when it is missing."""
        if 'step' not in group:
            group['step'] = 0

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # a single update probability, taken from the first group, drives all groups.
        update_prob: Union[float, Callable] = self.param_groups[0]['pre_conditioner_update_probability']
        if callable(update_prob):
            update_prob = update_prob(self.prob_step)  # pyright: ignore[reportAssignmentType]

        # update the pre-conditioner once every ~1/update_prob steps. A non-positive probability never updates
        # (and would otherwise raise ZeroDivisionError for a plain float).
        self.update_counter += 1
        do_update: bool = bool(update_prob > 0) and (  # pyright: ignore[reportOperatorIssue]
            self.update_counter >= 1 / update_prob  # pyright: ignore[reportOperatorIssue]
        )
        if do_update:
            self.update_counter = 0
        self.prob_step += 1

        # balancing only happens on steps that also update the pre-conditioner.
        balance: bool = np.random.random() < self.balance_prob and do_update

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            bias_correction1: float = self.debias(group['momentum'], group['step'])

            mu_dtype, precondition_dtype = group['mu_dtype'], group['precondition_dtype']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                # honor `maximize` (previously stored but never applied) by negating the gradient in place.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
                    state['Q'], state['expressions'] = initialize_q_expressions(
                        p,
                        group['precondition_init_scale'],
                        group['max_size_triangular'],
                        group['min_ndim_triangular'],
                        group['memory_save_mode'],
                        dtype=precondition_dtype,
                    )

                momentum_buffer = state['momentum_buffer']
                momentum_buffer.mul_(group['momentum']).add_(grad, alpha=1.0 - group['momentum'])

                if mu_dtype is not None:
                    momentum_buffer = momentum_buffer.to(dtype=mu_dtype, non_blocking=True)

                de_biased_momentum = (momentum_buffer / bias_correction1).to(
                    dtype=precondition_dtype, non_blocking=True
                )

                if grad.dim() > 1 and balance:
                    balance_q(state['Q'])

                if do_update:
                    # feed the pre-conditioner update in the pre-conditioner's dtype, like the momentum path does.
                    precondition_target = (
                        de_biased_momentum
                        if group['momentum_into_precondition_update']
                        else grad.to(dtype=precondition_dtype, non_blocking=True)
                    )
                    update_precondition(
                        state['Q'],
                        state['expressions'],
                        torch.randn_like(de_biased_momentum, dtype=precondition_dtype),
                        precondition_target,
                        group['precondition_lr'],
                        self.eps,
                    )

                precondition_grad = get_precondition_grad(state['Q'], state['expressions'], de_biased_momentum).to(
                    dtype=p.dtype, non_blocking=True
                )

                # cap the RMS of the pre-conditioned update at 1.1.
                precondition_grad.mul_(torch.clamp(1.1 / (precondition_grad.square().mean().sqrt() + 1e-6), max=1.0))

                # weight decay is added to the final update (bypassing the pre-conditioner); 0-D/1-D params skip it.
                if group['weight_decay'] != 0 and p.dim() >= 2:
                    precondition_grad.add_(p, alpha=group['weight_decay'])

                p.add_(precondition_grad, alpha=-group['lr'])

        return loss

Lamb

Bases: BaseOptimizer

Large Batch Optimization for Deep Learning.

This Lamb implementation is based on the paper v3, which does not use de-biasing.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and of its square.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
rectify bool

Perform the rectified update similar to RAdam.

False
degenerated_to_sgd bool

Degenerate to SGD.

False
n_sma_threshold int

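Minimum approximated SMA length for the rectified update (rectify=True) to use the adaptive denominator. Recommended value is 5.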

5
grad_averaging bool

Whether to apply (1 - beta1) to the gradient when calculating the running average of the gradient.

True
max_grad_norm float

Max gradient norm to clip.

1.0
adam bool

Always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes.

False
pre_norm bool

Perform pre-normalization of all gradients.

False
eps float

Term added to the denominator to improve numerical stability.

1e-06
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lamb.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
class Lamb(BaseOptimizer):
    """Large Batch Optimization for Deep Learning.

    This Lamb implementation is based on the paper v3, which does not use de-biasing.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of its square.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        rectify (bool): Perform the rectified update similar to RAdam.
        degenerated_to_sgd (bool): Degenerate to SGD.
        n_sma_threshold (int): Minimum approximated SMA length for the rectified update to use the adaptive
            denominator; below it the momentum is applied without it. Recommended is 5.
        grad_averaging (bool): Whether to apply (1 - beta1) to the gradient when calculating the running average of
            the gradient.
        max_grad_norm (float): Max global gradient norm used by ``pre_norm``. 0.0 disables the clipping.
        adam (bool): Always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes.
        pre_norm (bool): Perform pre-normalization of all gradients: every gradient is multiplied by
            ``min(1, max_grad_norm / global_grad_norm)``, clipping the global gradient norm to ``max_grad_norm``.
        eps (float): Term added to the denominator to improve numerical stability.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    clamp: float = 10.0

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        degenerated_to_sgd: bool = False,
        n_sma_threshold: int = 5,
        grad_averaging: bool = True,
        max_grad_norm: float = 1.0,
        adam: bool = False,
        pre_norm: bool = False,
        eps: float = 1e-6,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(max_grad_norm, 'max_grad_norm')
        self.validate_non_negative(eps, 'eps')

        self.degenerated_to_sgd = degenerated_to_sgd
        self.n_sma_threshold = n_sma_threshold
        self.pre_norm = pre_norm
        self.foreach = foreach
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'grad_averaging': grad_averaging,
            'max_grad_norm': max_grad_norm,
            'adam': adam,
            'eps': eps,
            'foreach': foreach,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Lamb'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the group step counter and per-parameter moment buffers (plus the AdaNorm buffer if enabled)."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        """Check if foreach can be used for this group.

        Foreach is disabled when using features that require per-parameter handling:
        - AdaNorm
        - Rectify (has conditional logic per parameter)
        """
        if group.get('foreach') is False:
            return False

        if group.get('adanorm') or group.get('rectify'):
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        grad_norm: Union[torch.Tensor, float],
        exp_avgs: List[torch.Tensor],
        exp_avg_sqs: List[torch.Tensor],
        step_size: float,
    ) -> None:
        """Apply the (non-rectified, non-AdaNorm) Lamb update to a whole group with multi-tensor ops."""
        beta1, beta2 = group['betas']
        eps = group['eps']
        beta3: float = 1.0 - beta1 if group['grad_averaging'] else 1.0

        if self.maximize:
            torch._foreach_neg_(grads)

        if self.pre_norm:
            if isinstance(grad_norm, torch.Tensor):
                grad_norm = grad_norm.item()

            # `grad_norm` is a clipping coefficient in (0, 1]; multiplying clips the global norm to max_grad_norm.
            torch._foreach_mul_(grads, grad_norm)

        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=group['lr'],
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        torch._foreach_mul_(exp_avgs, beta1)
        torch._foreach_add_(exp_avgs, grads, alpha=beta3)

        torch._foreach_mul_(exp_avg_sqs, beta2)
        torch._foreach_addcmul_(exp_avg_sqs, grads, grads, value=1.0 - beta2)

        # updates = exp_avg / (sqrt(exp_avg_sq) + eps)
        updates = torch._foreach_sqrt(exp_avg_sqs)
        torch._foreach_add_(updates, eps)
        torch._foreach_reciprocal_(updates)
        torch._foreach_mul_(updates, exp_avgs)

        weight_norms = torch._foreach_norm(params)
        torch._foreach_clamp_max_(weight_norms, self.clamp)

        p_norms = torch._foreach_norm(updates)

        for p, update, wn, pn in zip(params, updates, weight_norms, p_norms):
            trust_ratio: float = 1.0
            if wn != 0 and pn != 0:
                trust_ratio = (wn / (pn + eps)).item()

            state = self.state[p]
            state['weight_norm'] = wn
            state['adam_norm'] = pn
            state['trust_ratio'] = trust_ratio

            if group['adam']:
                trust_ratio = 1.0

            p.add_(update, alpha=-step_size * trust_ratio)

    @torch.no_grad()
    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
        """Return the gradient clipping coefficient used by ``pre_norm``.

        Despite the name, the returned value is ``min(1, max_grad_norm / (global_grad_norm + eps))``; multiplying every
        gradient by it clips the global gradient norm to ``max_grad_norm``. Returns 1.0 when ``max_grad_norm`` is 0.
        """
        if self.defaults['max_grad_norm'] == 0.0:
            return 1.0

        global_grad_norm = get_global_gradient_norm(self.param_groups)
        global_grad_norm.sqrt_().add_(self.defaults['eps'])

        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

    def update(
        self,
        p: torch.Tensor,
        group: ParamGroup,
        grad_norm: Union[torch.Tensor, float],
        n_sma: float,
        step_size: float,
        beta1: float,
        beta2: float,
        beta3: float,
    ) -> None:
        """Apply the Lamb update to a single parameter (supports AdaNorm and rectification)."""
        grad = p.grad
        if grad is None:
            return

        if self.pre_norm:
            # `grad_norm` is the clipping coefficient from `get_global_gradient_norm`, so scale by multiplication.
            grad.mul_(grad_norm)

        self.maximize_gradient(grad, maximize=self.maximize)

        state = self.state[p]

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

        p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

        # apply weight decay before the moment updates, as the foreach path does, and pass the gradient so coupled
        # (L2) weight decay can be applied to it.
        self.apply_weight_decay(
            p=p,
            grad=grad,
            lr=group['lr'],
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        s_grad = self.get_adanorm_gradient(
            grad=grad,
            adanorm=group.get('adanorm', False),
            exp_grad_norm=state.get('exp_grad_adanorm', None),
            r=group.get('adanorm_r', None),
        )

        exp_avg.mul_(beta1).add_(s_grad, alpha=beta3)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

        de_nom: Optional[torch.Tensor] = None

        if group['rectify']:
            update = p.clone()
            if n_sma >= self.n_sma_threshold:
                de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                update.addcdiv_(exp_avg, de_nom, value=-step_size)
            else:
                update.add_(exp_avg, alpha=-step_size)
        else:
            update = exp_avg / exp_avg_sq.sqrt().add_(group['eps'])

        weight_norm = torch.linalg.norm(p).clamp_(min=0, max=self.clamp)
        p_norm = torch.linalg.norm(update)
        trust_ratio: float = 1.0 if weight_norm == 0 or p_norm == 0 else weight_norm / (p_norm + group['eps'])

        state['weight_norm'] = weight_norm
        state['adam_norm'] = p_norm
        state['trust_ratio'] = trust_ratio

        if group['adam']:
            trust_ratio = 1.0

        if group['rectify']:
            if n_sma >= self.n_sma_threshold:
                p.addcdiv_(exp_avg, de_nom, value=-step_size * trust_ratio)
            else:
                p.add_(exp_avg, alpha=-step_size * trust_ratio)
        else:
            p.add_(update, alpha=-step_size * trust_ratio)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the closure's loss, or None when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        grad_norm = 1.0
        if self.pre_norm:
            grad_norm = self.get_global_gradient_norm()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            beta3: float = 1.0 - beta1 if group['grad_averaging'] else 1.0
            bias_correction1: float = self.debias(beta1, group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['exp_avg', 'exp_avg_sq']
                )
                if params:
                    self._step_foreach(
                        group, params, grads, grad_norm, state_dict['exp_avg'], state_dict['exp_avg_sq'], step_size
                    )
            else:
                for p in group['params']:
                    self.update(p, group, grad_norm, n_sma, step_size, beta1, beta2, beta3)

        return loss

LaProp

Bases: BaseOptimizer

Separating Momentum and Adaptivity in Adam.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.0004
betas Betas

Coefficients used for computing running averages of the gradient and of its square.

(0.9, 0.999)
centered bool

If True, use the centered variant of Adam.

False
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
ams_bound bool

Whether to use the AMSBound variant.

False
eps float

Epsilon value for numerical stability.

1e-15
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/laprop.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class LaProp(BaseOptimizer):
    """Separating Momentum and Adaptivity in Adam.

    The gradient is divided by the (bias-corrected) square root of its running second moment *before* it is
    accumulated into the momentum buffer, which decouples the momentum coefficient from the adaptivity coefficient.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the normalized gradient (momentum) and
            the squared gradient.
        centered (bool): If True, use the centered variant of Adam (the squared running mean of the gradient is
            subtracted from the second moment).
        steps_before_using_centered (int): Number of initial steps during which the centered correction is not
            applied, even when `centered` is True.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        ams_bound (bool): Whether to use the AMSBound variant.
        eps (float): Epsilon value for numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 4e-4,
        betas: Betas = (0.9, 0.999),
        centered: bool = False,
        steps_before_using_centered: int = 10,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-15,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.steps_before_using_centered: int = steps_before_using_centered
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'centered': centered,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'LaProp'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and lazily allocate per-parameter state for parameters that have a grad."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_lr_1'] = 0.0
                state['exp_avg_lr_2'] = 0.0

                if group['centered']:
                    state['exp_mean_avg_beta2'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): Optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: The value returned by `closure`, or None when no closure is given.

        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, exp_mean_avg_beta2 = (
                    state['exp_avg'],
                    state['exp_avg_sq'],
                    state.get('exp_mean_avg_beta2', None),
                )

                p, grad, exp_avg, exp_avg_sq, exp_mean_avg_beta2 = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, exp_mean_avg_beta2
                )

                # Weight decay must be applied before the gradient is consumed: in the coupled (L2) mode it adds to
                # `grad`, which would otherwise have no effect on this step's update.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                state['exp_avg_lr_1'] = state['exp_avg_lr_1'] * beta1 + (1.0 - beta1) * group['lr']
                state['exp_avg_lr_2'] = state['exp_avg_lr_2'] * beta2 + (1.0 - beta2)

                # lr-weighted bias correction of the momentum (a zero lr would make the ratio undefined)
                bias_correction1: float = state['exp_avg_lr_1'] / group['lr'] if group['lr'] != 0.0 else 1.0
                step_size: float = 1.0 / bias_correction1

                de_nom = exp_avg_sq
                if group['centered']:
                    exp_mean_avg_beta2.mul_(beta2).add_(grad, alpha=1.0 - beta2)
                    if group['step'] > self.steps_before_using_centered:
                        # Out-of-place on purpose: an in-place `-=` would corrupt the stored second moment.
                        # The clamp guards against tiny negative values caused by floating-point rounding.
                        de_nom = (exp_avg_sq - exp_mean_avg_beta2.pow(2)).clamp_(min=0.0)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=de_nom,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                # LaProp: normalize the gradient first, then accumulate it into the momentum.
                exp_avg.mul_(beta1).addcdiv_(grad, de_nom, value=(1.0 - beta1) * group['lr'])

                if group.get('cautious'):
                    update = exp_avg.clone()
                    self.apply_cautious(update, grad)
                else:
                    update = exp_avg

                p.add_(update, alpha=-step_size)

        return loss

LARS

Bases: BaseOptimizer

Layer-wise Adaptive Rate Scaling (no rate scaling or weight decay for parameters <= 1D).

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
weight_decay float

Weight decay (L2 penalty).

0.0
momentum float

Momentum.

0.9
dampening float

Dampening for momentum.

0.0
trust_coefficient float

Trust coefficient.

0.001
nesterov bool

Enables Nesterov momentum.

False
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lars.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class LARS(BaseOptimizer):
    """Layer-wise Adaptive Rate Scaling (no rate scaling or weight decay for parameters <= 1D).

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        weight_decay (float): Weight decay (L2 penalty).
        momentum (float): Momentum.
        dampening (float): Dampening for momentum.
        trust_coefficient (float): Trust coefficient.
        nesterov (bool): Enables Nesterov momentum. Nesterov forces the per-parameter (non-foreach) path.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.9,
        dampening: float = 0.0,
        trust_coefficient: float = 1e-3,
        nesterov: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(dampening, 'dampening', 0.0, 1.0)
        self.validate_non_negative(trust_coefficient, 'trust_coefficient')

        self.foreach = foreach
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'weight_decay': weight_decay,
            'momentum': momentum,
            'dampening': dampening,
            'trust_coefficient': trust_coefficient,
            'nesterov': nesterov,
            'foreach': foreach,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Lars'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and, when momentum is enabled, allocate momentum buffers."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if group['momentum'] > 0.0:
                state = self.state[p]

                # Start from zeros (as in the reference LARS). The raw gradient available here is neither
                # sign-flipped for `maximize` nor trust-ratio scaled, and the step adds the processed gradient
                # into the buffer anyway, so seeding with it would double-count the first gradient.
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        """Check if foreach can be used for this group.

        Foreach is disabled when using features that require per-parameter handling:
        - Nesterov momentum (requires per-parameter gradient modification)
        """
        if group.get('foreach') is False:
            return False

        if group.get('nesterov'):
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        momentum_buffers: List[torch.Tensor],
    ) -> None:
        """Multi-tensor update; trust-ratio scaling and weight decay apply only to parameters with ndim > 1."""
        if self.maximize:
            torch._foreach_neg_(grads)

        masks = [p.ndim > 1 for p in params]
        masked_params = [p for p, m in zip(params, masks) if m]
        # these are the same tensor objects as in `grads`, so in-place ops below are visible through `grads`
        masked_grads = [g for g, m in zip(grads, masks) if m]

        if masked_params:
            param_norms = torch._foreach_norm(masked_params)
            grad_norms = torch._foreach_norm(masked_grads)

            trust_ratios = []
            for pn, gn in zip(param_norms, grad_norms):
                one = torch.ones_like(pn)
                trust_ratio = torch.where(
                    pn > 0.0,
                    torch.where(gn > 0.0, (group['trust_coefficient'] * pn / gn), one),
                    one,
                )
                trust_ratios.append(trust_ratio)

            torch._foreach_add_(masked_grads, masked_params, alpha=group['weight_decay'])
            torch._foreach_mul_(masked_grads, trust_ratios)

        if group['momentum'] > 0.0:
            torch._foreach_mul_(momentum_buffers, group['momentum'])
            torch._foreach_add_(momentum_buffers, grads, alpha=1.0 - group['dampening'])
            torch._foreach_copy_(grads, momentum_buffers)

        torch._foreach_add_(params, grads, alpha=-group['lr'])

    def _step_per_param(self, group: ParamGroup) -> None:
        """Per-parameter update; trust-ratio scaling and weight decay apply only to parameters with ndim > 1."""
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            if p.ndim > 1:
                param_norm = torch.linalg.norm(p)
                update_norm = torch.linalg.norm(grad)

                one = torch.ones_like(param_norm)

                # fall back to a ratio of 1 when either norm is zero
                trust_ratio = torch.where(
                    param_norm > 0.0,
                    torch.where(update_norm > 0.0, (group['trust_coefficient'] * param_norm / update_norm), one),
                    one,
                )

                grad.add_(p, alpha=group['weight_decay'])
                grad.mul_(trust_ratio)

            if group['momentum'] > 0.0:
                mb = state['momentum_buffer']
                mb.mul_(group['momentum']).add_(grad, alpha=1.0 - group['dampening'])

                if group['nesterov']:
                    grad.add_(mb, alpha=group['momentum'])
                else:
                    grad.copy_(mb)

            p.add_(grad, alpha=-group['lr'])

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step, using the foreach path when possible and momentum is enabled."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if self._can_use_foreach(group) and group['momentum'] > 0.0:
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['momentum_buffer']
                )
                if params:
                    self._step_foreach(group, params, grads, state_dict['momentum_buffer'])
            else:
                self._step_per_param(group)

        return loss

Lion

Bases: BaseOptimizer

Symbolic Discovery of Optimization Algorithms.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.0001
betas Betas

Pair (beta1, beta2): beta1 interpolates between the momentum and the current gradient to form the signed update; beta2 is the decay rate of the momentum buffer.

(0.9, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lion.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
class Lion(BaseOptimizer):
    """Symbolic Discovery of Optimization Algorithms.

    The update is the sign of an interpolation between the momentum buffer and the current gradient; the momentum
    buffer is then refreshed with a separate decay rate.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Pair (beta1, beta2). beta1 interpolates between momentum and gradient to form the signed
            update; beta2 is the decay rate of the momentum buffer.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-4,
        betas: Betas = (0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.foreach = foreach
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'foreach': foreach,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Lion'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Ensure the step counter exists and allocate momentum (and AdaNorm) state on first use."""
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if state:
                continue  # already initialized

            state['exp_avg'] = torch.zeros_like(param)
            if group.get('adanorm'):
                state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        """Return whether the multi-tensor path may be used for ``group``.

        The per-parameter path is required for gradient centralization (``use_gc``), AdaNorm and cautious updates.
        """
        requested = group.get('foreach')
        if requested is False:
            return False

        if any(group.get(flag) for flag in ('use_gc', 'adanorm', 'cautious')):
            return False

        return self.can_use_foreach(group, requested)

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
    ) -> None:
        """Apply one Lion update to all ``params`` at once with ``torch._foreach_*`` kernels."""
        beta1, beta2 = group['betas']
        step_lr = group['lr']

        if self.maximize:
            torch._foreach_neg_(grads)

        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=step_lr,
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        # signed update: sign(beta1 * m + (1 - beta1) * g)
        signed = torch._foreach_mul(exp_avgs, beta1)
        torch._foreach_add_(signed, grads, alpha=1.0 - beta1)
        torch._foreach_sign_(signed)

        # momentum refresh: m <- beta2 * m + (1 - beta2) * g
        torch._foreach_mul_(exp_avgs, beta2)
        torch._foreach_add_(exp_avgs, grads, alpha=1.0 - beta2)

        torch._foreach_add_(params, signed, alpha=-step_lr)

    def _step_per_param(self, group: ParamGroup) -> None:
        """Apply one Lion update parameter by parameter (supports use_gc, AdaNorm and cautious updates)."""
        beta1, beta2 = group['betas']
        step_lr = group['lr']

        for param in group['params']:
            if param.grad is None:
                continue

            grad = param.grad
            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[param]
            param, grad, momentum = self.view_as_real(param, grad, state['exp_avg'])

            if group.get('use_gc'):
                centralize_gradient(grad, gc_conv_only=False)

            self.apply_weight_decay(
                p=param,
                grad=grad,
                lr=step_lr,
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=group['fixed_decay'],
            )

            momentum_grad = self.get_adanorm_gradient(
                grad=grad,
                adanorm=group.get('adanorm', False),
                exp_grad_norm=state.get('exp_grad_adanorm', None),
                r=group.get('adanorm_r', None),
            )

            # the signed update must be formed from the momentum *before* it is refreshed below
            signed = momentum.clone().mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
            momentum.mul_(beta2).add_(momentum_grad, alpha=1.0 - beta2)

            if group.get('cautious'):
                self.apply_cautious(signed, grad)

            param.add_(signed, alpha=-step_lr)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step and return the closure's loss (or None)."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if not self._can_use_foreach(group):
                self._step_per_param(group)
                continue

            params, grads, state_dict = self.collect_trainable_params(group, self.state, state_keys=['exp_avg'])
            if params:
                self._step_foreach(group, params, grads, state_dict['exp_avg'])

        return loss

load_ao_optimizer(optimizer)

Load a TorchAO optimizer class.

Source code in pytorch_optimizer/optimizer/__init__.py
299
300
301
302
303
304
305
306
307
308
309
310
def load_ao_optimizer(optimizer: str) -> OptimizerType:  # pragma: no cover
    """Resolve a TorchAO low-bit optimizer class from its name.

    Names are matched by substring, checked in the order adamw8bit, adamw4bit, adamwfp8.

    Raises:
        NotImplementedError: If ``optimizer`` contains none of the supported names.
    """
    from torchao.prototype import low_bit_optim  # noqa: PLC0415

    candidates = (
        ('adamw8bit', 'AdamW8bit'),
        ('adamw4bit', 'AdamW4bit'),
        ('adamwfp8', 'AdamWFp8'),
    )
    for token, class_name in candidates:
        if token in optimizer:
            return getattr(low_bit_optim, class_name)

    raise NotImplementedError(f'not implemented optimizer {optimizer}')

load_bnb_optimizer(optimizer)

Load a bitsandbytes optimizer class.

Source code in pytorch_optimizer/optimizer/__init__.py
278
279
280
281
282
283
284
285
286
def load_bnb_optimizer(optimizer: str) -> OptimizerType:  # pragma: no cover
    """Resolve a bitsandbytes optimizer class from its name.

    The first ``(alias, class name)`` entry of ``BNB_OPTIMIZERS`` whose alias is a substring of ``optimizer`` wins.

    Raises:
        NotImplementedError: If no alias matches.
    """
    from bitsandbytes import optim  # noqa: PLC0415

    for alias, class_name in BNB_OPTIMIZERS:
        if alias not in optimizer:
            continue
        return getattr(optim, class_name)

    raise NotImplementedError(f'not implemented optimizer {optimizer}')

load_optimizer(optimizer)

Load an optimizer class by name.

Source code in pytorch_optimizer/optimizer/__init__.py
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def load_optimizer(optimizer: str) -> OptimizerType:
    """Look up an optimizer class by (case-insensitive) name.

    Names prefixed with ``bnb``, ``q_galore`` or ``torchao`` are delegated to the matching third-party loader, which
    requires the corresponding package and CUDA. Every other name is looked up in ``OPTIMIZERS``.

    Raises:
        ImportError: If a third-party optimizer is requested but its package or CUDA is unavailable.
        NotImplementedError: If the name is unknown.
    """
    name: str = optimizer.lower()

    if name.startswith('bnb'):
        if not (HAS_BNB and torch.cuda.is_available()):
            raise ImportError(f'bitsandbytes and CUDA required for the optimizer {name}')
        return load_bnb_optimizer(name)  # pragma: no cover

    if name.startswith('q_galore'):
        if not (HAS_Q_GALORE and torch.cuda.is_available()):
            raise ImportError(f'bitsandbytes, q-galore-torch, and CUDA required for the optimizer {name}')
        return load_q_galore_optimizer(name)  # pragma: no cover

    if name.startswith('torchao'):
        if not (HAS_TORCHAO and torch.cuda.is_available()):
            raise ImportError(
                f'torchao required for the optimizer {name}. '
                'usage: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#usage'
            )
        return load_ao_optimizer(name)  # pragma: no cover

    if name in OPTIMIZERS:
        return OPTIMIZERS[name]

    raise NotImplementedError(f'not implemented optimizer : {name}')

load_q_galore_optimizer(optimizer)

Load a Q-GaLore optimizer class.

Source code in pytorch_optimizer/optimizer/__init__.py
289
290
291
292
293
294
295
296
def load_q_galore_optimizer(optimizer: str) -> OptimizerType:  # pragma: no cover
    """Resolve the Q-GaLore optimizer class; only names containing ``adamw8bit`` are supported.

    Raises:
        NotImplementedError: If ``optimizer`` does not contain ``adamw8bit``.
    """
    import q_galore_torch  # noqa: PLC0415

    if 'adamw8bit' not in optimizer:
        raise NotImplementedError(f'not implemented optimizer {optimizer}')

    return q_galore_torch.QGaLoreAdamW8bit

LOMO

Bases: BaseOptimizer

Full Parameter Fine-tuning for Large Language Models with Limited Resources.

Reference: https://github.com/OpenLMLab/LOMO/blob/main/src/lomo.py Check usage: https://github.com/OpenLMLab/LOMO/blob/main/lomo/src/lomo_trainer.py

Parameters:

Name Type Description Default
model Module

PyTorch model.

required
lr float

Learning rate.

0.001
clip_grad_norm Optional[float]

Gradient norm clipping value.

None
clip_grad_value Optional[float]

Gradient value clipping threshold.

None
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lomo.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
class LOMO(BaseOptimizer):
    """Full Parameter Fine-tuning for Large Language Models with Limited Resources.

    Reference: https://github.com/OpenLMLab/LOMO/blob/main/src/lomo.py
    Check usage: https://github.com/OpenLMLab/LOMO/blob/main/lomo/src/lomo_trainer.py

    The update runs inside a gradient hook registered on every trainable parameter. When `clip_grad_norm` is set,
    call `grad_norm(loss)` before `fused_backward(loss, lr)`.

    Args:
        model (nn.Module): PyTorch model.
        lr (float): Learning rate.
        clip_grad_norm (Optional[float]): Gradient norm clipping value.
        clip_grad_value (Optional[float]): Gradient value clipping threshold.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    Raises:
        ValueError: If the model parameters are float16 and `clip_grad_norm` is None.

    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-3,
        clip_grad_norm: Optional[float] = None,
        clip_grad_value: Optional[float] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(clip_grad_norm, 'clip_grad_norm')
        self.validate_non_negative(clip_grad_value, 'clip_grad_value')

        self.model = model
        self.lr = lr
        self.clip_grad_norm = clip_grad_norm
        self.clip_grad_value = clip_grad_value
        self.maximize = maximize

        self.local_rank: int = int(os.environ.get('LOCAL_RANK', '0'))

        self.gather_norm: bool = False
        self.grad_norms: List[torch.Tensor] = []
        # set by grad_norm() to a (clamped) tensor scale factor
        self.clip_coef: Optional[torch.Tensor] = None

        p0: torch.Tensor = next(iter(self.model.parameters()))

        # DeepSpeed ZeRO-3 partitioned parameters carry a `ds_tensor` attribute
        self.grad_func: Callable[[Any], Any] = (
            self.fuse_update_zero3() if hasattr(p0, 'ds_tensor') else self.fuse_update()
        )

        self.loss_scaler: Optional[DynamicLossScaler] = None
        if p0.dtype == torch.float16:
            if clip_grad_norm is None:
                raise ValueError('loss scaling is recommended to be used with grad norm to get better performance.')

            self.loss_scaler = DynamicLossScaler(init_scale=2 ** 16)  # fmt: skip

        for _, p in self.model.named_parameters():
            if p.requires_grad:
                p.register_hook(self.grad_func)

        defaults: Defaults = {'lr': lr}

        super().__init__(self.model.parameters(), defaults)

    def __str__(self) -> str:
        return 'LOMO'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

    def _mark_overflow(self) -> None:
        """Record a detected overflow on the loss scaler, if one is in use."""
        if self.loss_scaler:
            self.loss_scaler.has_overflow_serial = True

    def fuse_update(self) -> Callable[[Any], Any]:
        """Build the gradient hook that either gathers grad norms or applies the SGD update in fp32."""

        @torch.no_grad()
        def func(x: Any) -> Any:
            for _, p in self.model.named_parameters():
                if not p.requires_grad or p.grad is None:
                    continue

                if (self.loss_scaler and self.loss_scaler.has_overflow_serial) or has_overflow(p.grad):
                    # Drop the gradient and stop for this pass. Without a loss scaler (non-fp16 params)
                    # there is no overflow flag to set, so it must not be dereferenced.
                    p.grad = None
                    self._mark_overflow()
                    break

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                if self.loss_scaler:
                    grad_fp32.div_(self.loss_scaler.loss_scale)

                if self.gather_norm:
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                else:
                    if self.clip_grad_value is not None and self.clip_grad_value > 0.0:
                        grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)
                    if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is not None:
                        grad_fp32.mul_(self.clip_coef)

                    p_fp32 = p.to(torch.float32)
                    p_fp32.add_(grad_fp32, alpha=-self.lr)
                    p.copy_(p_fp32)

            return x

        return func

    def fuse_update_zero3(self) -> Callable[[Any], Any]:  # pragma: no cover
        """Build the ZeRO-3 variant of the hook, updating only this rank's partition (`ds_tensor`)."""

        @torch.no_grad()
        def func(x: torch.Tensor) -> torch.Tensor:
            for _, p in self.model.named_parameters():
                if p.grad is None:
                    continue

                all_reduce(p.grad, op=ReduceOp.AVG, async_op=False)

                if (self.loss_scaler and self.loss_scaler.has_overflow_serial) or has_overflow(p.grad):
                    p.grad = None
                    self._mark_overflow()
                    break

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                param_fp32 = p.ds_tensor.to(torch.float32)
                if self.loss_scaler:
                    grad_fp32.div_(self.loss_scaler.loss_scale)

                if self.gather_norm:
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                else:
                    one_dim_grad_fp32 = grad_fp32.view(-1)

                    partition_size: int = p.ds_tensor.numel()
                    start: int = partition_size * self.local_rank
                    end: int = min(start + partition_size, grad_fp32.numel())

                    partitioned_grad_fp32 = one_dim_grad_fp32.narrow(0, start, end - start)

                    # same condition as the non-ZeRO path: a threshold of 0 disables value clipping
                    if self.clip_grad_value is not None and self.clip_grad_value > 0.0:
                        partitioned_grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)

                    if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is not None:
                        partitioned_grad_fp32.mul_(self.clip_coef)

                    partitioned_p = param_fp32.narrow(0, 0, end - start)
                    partitioned_p.add_(partitioned_grad_fp32, alpha=-self.lr)

                    p.ds_tensor[: end - start] = partitioned_p  # fmt: skip

            return x

        return func

    def fused_backward(self, loss, lr: float):
        """Run backward on `loss` with the fused update at learning rate `lr`.

        Raises:
            ValueError: If `clip_grad_norm` is set but `grad_norm()` has not been called yet.
        """
        self.lr = lr

        if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is None:
            raise ValueError(
                'clip_grad_norm is not None, but clip_coef is None. '
                'Please call optimizer.grad_norm() before optimizer.fused_backward().'
            )

        if self.loss_scaler:
            loss = loss * self.loss_scaler.loss_scale

        loss.backward()

        # process the last parameter, whose hook may fire before its gradient is ready
        self.grad_func(0)

    def grad_norm(self, loss):
        """Compute the global gradient norm of `loss` (retaining the graph) and set `clip_coef`.

        On an fp16 overflow the loss scale is reduced, gradients are cleared and `clip_coef` is left unchanged.
        """
        self.gather_norm = True
        self.grad_norms = []

        if self.loss_scaler:
            self.loss_scaler.has_overflow_serial = False
            loss = loss * self.loss_scaler.loss_scale

        loss.backward(retain_graph=True)

        self.grad_func(0)

        if self.loss_scaler and self.loss_scaler.has_overflow_serial:
            self.loss_scaler.update_scale(overflow=True)

            with torch.no_grad():
                for _, p in self.model.named_parameters():
                    p.grad = None
            return

        with torch.no_grad():
            self.grad_norms = torch.stack(self.grad_norms)

            total_norm = torch.norm(self.grad_norms, 2.0)
            self.clip_coef = torch.clamp(float(self.clip_grad_norm) / (total_norm + 1e-6), max=1.0)

        self.gather_norm = False

Lookahead

Bases: BaseOptimizer

k steps forward, 1 step back.

Parameters:

Name Type Description Default
optimizer OptimizerInstanceOrClass

Base optimizer.

required
k int

Number of lookahead steps.

5
alpha float

Linear interpolation factor.

0.5
pullback_momentum str

How the inner optimizer's momentum is adjusted on the interpolation update: 'none', 'reset', or 'pullback'.

'none'
Source code in pytorch_optimizer/optimizer/lookahead.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class Lookahead(BaseOptimizer):
    """k steps forward, 1 step back.

    Wraps a base optimizer and keeps a "slow" copy of every parameter. Every ``k`` calls to ``step`` (counted per
    parameter group), each parameter that currently has a gradient is pulled toward its slow copy with
    ``p = alpha * p + (1 - alpha) * slow``, and the slow copy is then overwritten with the result.

    Args:
        optimizer (OptimizerInstanceOrClass): Base optimizer.
        k (int): Number of lookahead steps.
        alpha (float): Linear interpolation factor. ``1.0`` keeps the fast weights, ``0.0`` resets them to the slow
            weights.
        pullback_momentum (str): How the inner optimizer's ``momentum_buffer`` is treated on each interpolation
            update. ``'none'`` leaves it untouched, ``'reset'`` zeroes it, and ``'pullback'`` interpolates it toward
            a slow momentum snapshot in the same way as the parameters.

    """

    def __init__(
        self,
        optimizer: OptimizerInstanceOrClass,
        k: int = 5,
        alpha: float = 0.5,
        pullback_momentum: str = 'none',
        **kwargs,
    ) -> None:
        self.validate_positive(k, 'k')
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        self.alpha = alpha
        self.k = k
        self.pullback_momentum = pullback_momentum

        self.state: State = defaultdict(dict)

        for group in self.param_groups:
            if 'counter' not in group:
                group['counter'] = 0

            for p in group['params']:
                state = self.state[p]
                # detach() keeps the snapshot out of autograd: copying from a parameter that requires grad outside of
                # no_grad would record a CopyBackwards node and make the slow weights a non-leaf graph tensor.
                state['slow_params'] = p.detach().clone()
                if self.pullback_momentum == 'pullback':
                    state['slow_momentum'] = torch.zeros_like(p)

        self.defaults: Defaults = {
            'lookahead_alpha': alpha,
            'lookahead_k': k,
            'lookahead_pullback_momentum': pullback_momentum,
            **self.optimizer.defaults,
        }

    @property
    def param_groups(self):
        r"""Parameter groups of the wrapped optimizer (shared, not copied)."""
        return self.optimizer.param_groups

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'k': self.k,
            'pullback_momentum': self.pullback_momentum,
        }

    @torch.no_grad()
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Delegate gradient clearing to the wrapped optimizer."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

    def backup_and_load_cache(self) -> None:
        r"""Backup cache parameters.

        Stores a detached copy of each parameter in ``state['backup_params']`` and loads the slow weights into the
        parameter in place.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['backup_params'] = p.detach().clone()
                p.data.copy_(state['slow_params'])

    def clear_and_load_backup(self) -> None:
        r"""Load backup parameters and remove them from the state."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                p.data.copy_(state['backup_params'])
                del state['backup_params']

    def state_dict(self) -> State:
        r"""Return the lookahead per-parameter state (keyed by parameter) and the base optimizer state dict."""
        lookahead_state: State = {p: dict(param_state) for p, param_state in self.state.items()}
        return {'lookahead_state': lookahead_state, 'base_optimizer': self.optimizer.state_dict()}

    def load_state_dict(self, state: State) -> None:
        r"""Load state."""
        lookahead_state = state['lookahead_state']
        self.state = defaultdict(dict, {p: dict(param_state) for p, param_state in lookahead_state.items()})
        self.optimizer.load_state_dict(state['base_optimizer'])

    @torch.no_grad()
    def update(self, group: Dict):
        r"""Interpolate the group's parameters (and optionally momentum) toward their slow copies."""
        for p in group['params']:
            if p.grad is None:
                continue

            state = self.state[p]

            slow = state['slow_params']

            p.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
            slow.copy_(p)

            if self.pullback_momentum == 'pullback':
                inner_state = self.optimizer.state[p]
                # only touch the inner optimizer state when pullback actually needs a buffer; a missing or ``None``
                # buffer (e.g. optimizers without momentum) starts from zeros.
                if inner_state.get('momentum_buffer') is None:
                    inner_state['momentum_buffer'] = torch.zeros_like(p)

                momentum = inner_state['momentum_buffer']
                momentum.mul_(self.alpha).add_(state['slow_momentum'], alpha=1.0 - self.alpha)
                # copy, never alias: the inner optimizer updates ``momentum_buffer`` in place, so an alias would make
                # the "slow" snapshot track the live buffer and the next pullback would blend the buffer with itself.
                state['slow_momentum'].copy_(momentum)
            elif self.pullback_momentum == 'reset':
                self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)

    def step(self, closure: Closure = None) -> Loss:
        r"""Step the base optimizer, then run a lookahead update on every group whose counter reaches ``k``."""
        loss: Loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group['counter'] += 1
            if group['counter'] >= self.k:
                group['counter'] = 0
                self.update(group)
        return loss

backup_and_load_cache()

Backup cache parameters.

Source code in pytorch_optimizer/optimizer/lookahead.py
84
85
86
87
88
89
90
91
def backup_and_load_cache(self) -> None:
    r"""Backup cache parameters.

    For every parameter, a detached copy of its current value is stored in ``state['backup_params']``. The parameter
    is then overwritten in place with ``state['slow_params']``.
    """
    for group in self.param_groups:
        for p in group['params']:
            state = self.state[p]
            # detach() keeps the backup out of the autograd graph; copying from a parameter that requires grad
            # (outside no_grad) into a fresh tensor would record a CopyBackwards node.
            state['backup_params'] = p.detach().clone()
            p.data.copy_(state['slow_params'])

clear_and_load_backup()

Load backup parameters.

Source code in pytorch_optimizer/optimizer/lookahead.py
93
94
95
96
97
98
99
def clear_and_load_backup(self) -> None:
    r"""Load backup parameters.

    Copies each parameter's ``state['backup_params']`` back into the parameter in place and removes that entry
    from the per-parameter state.
    """
    every_param = (param for group in self.param_groups for param in group['params'])
    for param in every_param:
        param_state = self.state[param]
        param.data.copy_(param_state.pop('backup_params'))

load_state_dict(state)

Load state.

Source code in pytorch_optimizer/optimizer/lookahead.py
105
106
107
108
109
def load_state_dict(self, state: State) -> None:
    r"""Load state.

    Rebuilds the lookahead per-parameter state from ``state['lookahead_state']``, shallow-copying each inner dict
    into a fresh ``defaultdict(dict)``. It then forwards ``state['base_optimizer']`` to the wrapped optimizer's
    ``load_state_dict``.
    """
    restored = defaultdict(dict)
    for param, param_state in state['lookahead_state'].items():
        restored[param] = dict(param_state)
    self.state = restored
    self.optimizer.load_state_dict(state['base_optimizer'])

LookSAM

Bases: BaseOptimizer

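Sharpness-Aware Minimization with periodic gradient reuse (LookSAM).

The full SAM gradient is recomputed only every k steps. On the intermediate steps, the stored component of the SAM gradient that is orthogonal to the plain gradient is added back to the current gradient.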

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer Optimizer

Base optimizer class, instantiated with LookSAM's parameter groups and the extra keyword arguments.

required
rho float

Size of the neighborhood for computing the max loss.

0.1
k int

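Interval, in steps, at which the full SAM gradient is recomputed.

10
alpha float

Weight of the re-injected orthogonal SAM gradient component on intermediate steps.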

0.7
adaptive bool

Element-wise Adaptive SAM.

False
use_gc bool

Perform gradient centralization, GCSAM variant.

False
perturb_eps float

Epsilon for perturbation.

1e-12
kwargs Dict

Additional parameters for optimizer.

{}
Example
model = YourModel()
base_optimizer = Ranger21
optimizer = LookSAM(model.parameters(), base_optimizer)

for input, output in data:
    # first forward-backward pass

    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    # make sure to do a full forward pass
    loss_function(output, model(input)).backward()
    optimizer.second_step(zero_grad=True)

Alternative example with a single closure-based step function::

model = YourModel()
base_optimizer = Ranger21
optimizer = LookSAM(model.parameters(), base_optimizer)

def closure():
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

for input, output in data:
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.step(closure)
    optimizer.zero_grad()
Source code in pytorch_optimizer/optimizer/sam.py
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
class LookSAM(BaseOptimizer):
    """Sharpness-Aware Minimization with periodic gradient reuse (LookSAM).

    The SAM ascent perturbation and the sharpness-aware gradient are computed only on every ``k``-th step. On those
    steps, the component of the perturbed gradient that is orthogonal to the unperturbed gradient is stored per
    parameter as ``gv``. On the intermediate steps no perturbation is applied. Instead, ``gv`` is rescaled to the norm
    of the current gradient and added to it with weight ``alpha`` before the base optimizer steps.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        base_optimizer (Optimizer): Base optimizer class, instantiated with this optimizer's param groups and
            ``kwargs``.
        rho (float): Size of the neighborhood for computing the max loss.
        k (int): Interval, in steps, at which the full SAM gradient is recomputed.
        alpha (float): Weight of the re-injected orthogonal component ``gv`` on intermediate steps.
        adaptive (bool): Element-wise Adaptive SAM.
        use_gc (bool): Perform gradient centralization, GCSAM variant.
        perturb_eps (float): Epsilon for perturbation. Also used as a lower bound for the gradient norms that appear
            as divisors, so all-zero gradients do not produce NaNs.
        kwargs (Dict): Additional parameters for optimizer.

    Example:
        ```python
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = LookSAM(model.parameters(), base_optimizer)

        for input, output in data:
            # first forward-backward pass

            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            # make sure to do a full forward pass
            loss_function(output, model(input)).backward()
            optimizer.second_step(zero_grad=True)

        Alternative example with a single closure-based step function::

        model = YourModel()
        base_optimizer = Ranger21
        optimizer = LookSAM(model.parameters(), base_optimizer)

        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss

        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step(closure)
            optimizer.zero_grad()
        ```

    """

    def __init__(
        self,
        params: ParamsT,
        base_optimizer: OptimizerType,
        rho: float = 0.1,
        k: int = 10,
        alpha: float = 0.7,
        adaptive: bool = False,
        use_gc: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')
        self.validate_positive(k, 'k')
        self.validate_range(alpha, 'alpha', 0.0, 1.0, '()')
        self.validate_non_negative(perturb_eps, 'perturb_eps')

        self.k = k
        self.alpha = alpha
        self.use_gc = use_gc
        self.perturb_eps = perturb_eps

        defaults: Defaults = {'rho': rho, 'adaptive': adaptive}
        defaults.update(kwargs)

        super().__init__(params, defaults)

        self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'LookSAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        pass

    def get_step(self):
        r"""Return the current step.

        Uses the first group's ``'step'`` if present. Otherwise uses the ``'step'`` of the first base-optimizer state
        entry, or 0 when that state is empty.
        """
        return (
            self.param_groups[0]['step']
            if 'step' in self.param_groups[0]
            else next(iter(self.base_optimizer.state.values()))['step'] if self.base_optimizer.state else 0
        )

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False) -> None:
        r"""Apply the SAM ascent perturbation on synchronisation steps.

        On every ``k``-th step, each parameter with a gradient is processed as follows:

        - the parameter is snapshotted into ``state['old_p']``;
        - its gradient is stored in ``state['old_grad_p']``;
        - the parameter is moved by ``rho * grad / global_grad_norm``, scaled element-wise by ``p ** 2`` when
          ``adaptive``.

        On the other steps the parameters are left untouched.

        Args:
            zero_grad (bool): Zero the gradients afterwards. This happens on every call, including non-sync steps,
                so the following backward pass does not accumulate on top of the previous gradients.

        """
        if self.get_step() % self.k == 0:
            device = self.param_groups[0]['params'][0].device

            grad_norm = get_global_gradient_norm(self.param_groups, device).add_(self.perturb_eps)

            for group in self.param_groups:
                scale = group['rho'] / grad_norm

                for p in group['params']:
                    if p.grad is None:
                        continue

                    grad = p.grad
                    if self.use_gc:
                        centralize_gradient(grad, gc_conv_only=False)

                    # per-parameter state: unique per tensor, unlike index-concatenated string keys
                    state = self.state[p]
                    state['old_p'] = p.clone()
                    state['old_grad_p'] = grad.clone()

                    e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)

                    p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        r"""Restore the parameters, build the LookSAM gradient and step the base optimizer.

        On sync steps, the component of the current gradient orthogonal to the stored unperturbed gradient is saved
        as ``gv``. On other steps, ``gv`` is rescaled to the current gradient norm and added with weight ``alpha``.

        Args:
            zero_grad (bool): Zero the gradients after the base optimizer step.

        """
        is_sync_step = self.get_step() % self.k == 0

        for group in self.param_groups:
            for p in group['params']:
                # Undo the ``first_step`` perturbation in place (keeping the parameter's storage). This is done for
                # every perturbed parameter, even one that received no gradient in the second pass.
                param_state = self.state.get(p)
                if param_state is not None and 'old_p' in param_state:
                    p.copy_(param_state.pop('old_p'))

                if p.grad is None:
                    continue

                state = self.state[p]

                grad = p.grad
                grad_norm = grad.norm(p=2)

                if is_sync_step:
                    old_grad_p = state.pop('old_grad_p', None)
                    if old_grad_p is None:
                        continue

                    g_unit = old_grad_p / old_grad_p.norm(p=2).clamp_min(self.perturb_eps)
                    g_s_unit = grad / grad_norm.clamp_min(self.perturb_eps)

                    # component of the sharpness-aware gradient orthogonal to the plain gradient
                    state['gv'] = torch.sub(grad, grad_norm * torch.sum(g_unit * g_s_unit) * g_unit)
                else:
                    gv = state.get('gv')
                    if gv is not None:
                        grad.add_(grad_norm / (gv.norm(p=2) + 1e-8) * gv, alpha=self.alpha)

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: Closure = None):
        r"""Run ``first_step``, re-evaluate ``closure`` with gradients enabled, then run ``second_step``."""
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step(zero_grad=True)

        with torch.enable_grad():
            closure()

        self.second_step()

    def load_state_dict(self, state_dict: Dict):
        r"""Load state and re-share the param groups with the base optimizer."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

LoRARite

Bases: BaseOptimizer

Robust Invariant Transformation Equilibration for LoRA optimization.

This optimizer expects LoRA factors in alternating order, such as lora_a_1, lora_b_1, lora_a_2, lora_b_2. Unpaired parameters and pairs with missing gradients are skipped, matching common fine-tuning workflows where only part of the model may receive gradients on a given step.

Parameters:

Name Type Description Default
params ParamsT

Iterable of LoRA parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for first-moment and matrix second-moment estimates.

(0.9, 0.999)
eps float

Term added to the denominator to improve numerical stability.

1e-06
relative_epsilon bool

Scale the root epsilon by the largest matrix second-moment eigenvalue.

False
clip_unmagnified_grad float

Global clipping threshold for unmagnified LoRA gradients. Disabled when 0.

1.0
update_capping float

Per-update RMS capping threshold after preconditioning. Disabled when 0.

0.0
update_skipping float

Skip unmagnified updates whose RMS is above this threshold. Disabled when 0.

1.0
weight_decay float

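Decoupled weight decay coefficient: weight_decay * param is added to the preconditioned update, and the sum is then scaled by the learning rate.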

0.0
apply_escape bool

Apply the RITE escape correction when rotating second-moment bases.

False
lora_l_dim int

LoRA rank dimension for left factors.

0
lora_r_dim int

LoRA rank dimension for right factors.

-1
maybe_inf_to_nan bool

Convert infinite update statistics to NaN before threshold checks.

True
balance_param bool

Balance the norms of each LoRA factor pair after applying the update.

False
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lora_rite.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
class LoRARite(BaseOptimizer):
    """Robust Invariant Transformation Equilibration for LoRA optimization.

    This optimizer expects LoRA factors in alternating order, such as ``lora_a_1, lora_b_1, lora_a_2, lora_b_2``.
    Unpaired parameters and pairs with missing gradients are skipped, matching common fine-tuning workflows where only
    part of the model may receive gradients on a given step.

    Each ``step`` runs in two passes:

    1. For every complete pair, compute the "unmagnified" gradients and accumulate their squared norms over all
       parameter groups.
    2. Update every pair, using the square root of that sum as the global norm for ``clip_unmagnified_grad``.

    All optimizer state of a pair is stored under its left (first) parameter.

    Args:
        params (ParamsT): Iterable of LoRA parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for first-moment and matrix second-moment estimates.
        eps (float): Term added to the denominator to improve numerical stability. ``eps ** 2`` is also stored as
            ``eps_root``.
        relative_epsilon (bool): Scale the root epsilon by the largest matrix second-moment eigenvalue.
        clip_unmagnified_grad (float): Global clipping threshold for unmagnified LoRA gradients. Disabled when 0.
        update_capping (float): Per-update RMS capping threshold after preconditioning. Disabled when 0.
        update_skipping (float): Skip unmagnified updates whose RMS is above this threshold. Disabled when 0.
        weight_decay (float): Decoupled weight decay coefficient. ``weight_decay * param`` is added to the
            preconditioned, momentum-averaged update, and the sum is then scaled by ``lr``.
        apply_escape (bool): Apply the RITE escape correction when rotating second-moment bases.
        lora_l_dim (int): LoRA rank dimension for left factors.
        lora_r_dim (int): LoRA rank dimension for right factors.
        maybe_inf_to_nan (bool): Convert infinite update statistics to NaN before threshold checks.
        balance_param (bool): Balance the norms of each LoRA factor pair after applying the update.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        eps: float = 1e-6,
        relative_epsilon: bool = False,
        clip_unmagnified_grad: float = 1.0,
        update_capping: float = 0.0,
        update_skipping: float = 1.0,
        weight_decay: float = 0.0,
        apply_escape: bool = False,
        lora_l_dim: int = 0,
        lora_r_dim: int = -1,
        maybe_inf_to_nan: bool = True,
        balance_param: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(eps, 'eps')
        self.validate_non_negative(clip_unmagnified_grad, 'clip_unmagnified_grad')
        self.validate_non_negative(update_capping, 'update_capping')
        self.validate_non_negative(update_skipping, 'update_skipping')
        self.validate_non_negative(weight_decay, 'weight_decay')

        # maybe_inf_to_nan is handled inside the helper (see helper.inf_to_nan calls below)
        self.helper = LoRARiteHelper(maybe_inf_to_nan=maybe_inf_to_nan)
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'eps': eps,
            'eps_root': eps**2,
            'relative_epsilon': relative_epsilon,
            'clip_unmagnified_grad': clip_unmagnified_grad,
            'update_capping': update_capping,
            'update_skipping': update_skipping,
            'weight_decay': weight_decay,
            'apply_escape': apply_escape,
            'lora_l_dim': lora_l_dim,
            'lora_r_dim': lora_r_dim,
            'balance_param': balance_param,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'LoRARite'

    @staticmethod
    def iter_lora_pairs(group: ParamGroup) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        r"""Return ``(left, right)`` pairs of consecutive parameters; a trailing unpaired parameter is dropped."""
        params = list(group['params'])
        return list(zip(params[::2], params[1::2]))

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialise the group step counter and reject sparse gradients or complex parameters."""
        if 'step' not in group:
            group['step'] = 0

        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

    def init_pair_state(
        self, group: ParamGroup, state: Dict[str, Any], param_left: torch.Tensor, param_right: torch.Tensor
    ) -> None:
        r"""Lazily create a pair's state (moments, bases, escape terms); no-op once ``state['step']`` exists."""
        if 'step' in state:
            return

        # NOTE(review): move_lora_dim_to_last presumably returns a view with the LoRA rank as the last axis — confirm
        # in LoRARiteHelper. All state below is shaped like these reshaped tensors.
        param_left, _ = self.helper.move_lora_dim_to_last(param_left, group['lora_l_dim'])
        param_right, _ = self.helper.move_lora_dim_to_last(param_right, group['lora_r_dim'])

        state['step'] = 0
        state['v_l'] = self.helper.create_preconditioner(param_left)
        state['v_r'] = self.helper.create_preconditioner(param_right)
        state['m_l'] = torch.zeros_like(param_left)
        state['m_r'] = torch.zeros_like(param_right)
        state['basis_l'] = torch.zeros_like(param_left)
        state['basis_r'] = torch.zeros_like(param_right)
        state['escape_l'] = torch.zeros((), dtype=param_left.dtype, device=param_left.device)
        state['escape_r'] = torch.zeros((), dtype=param_right.dtype, device=param_right.device)

    def build_pair_info(
        self,
        group: ParamGroup,
        param_left: torch.Tensor,
        param_right: torch.Tensor,
    ) -> Dict[str, Any]:
        r"""First pass for one pair: compute unmagnified gradients and stash intermediates in the pair state.

        Stores ``update_l/r``, ``rotate_inv_l/r`` and ``projection_l/r`` (consumed by ``apply_pair_update``). It also
        refreshes ``basis_l/r``. Returns the pair's state dict, which lives under ``param_left``.
        """
        helper = self.helper
        state = self.state[param_left]
        self.init_pair_state(group, state, param_left, param_right)

        param_left_2d, _ = helper.move_lora_dim_to_last(param_left, group['lora_l_dim'])
        param_right_2d, _ = helper.move_lora_dim_to_last(param_right, group['lora_r_dim'])

        basis_left, rotate_left = helper.get_rotation_and_basis(param_left_2d)
        basis_right, rotate_right = helper.get_rotation_and_basis(param_right_2d)
        rotate_inv_left = torch.linalg.pinv(rotate_left)
        rotate_inv_right = torch.linalg.pinv(rotate_right)

        # Projections from the previous step's bases onto the new ones. Note the cross pairing: the left factor's
        # statistics live in the right factor's basis and vice versa, matching the unmagnified gradients below.
        projection_left = basis_right.mT @ state['basis_r']
        projection_right = basis_left.mT @ state['basis_l']

        grad_left = helper.inf_to_nan(param_left.grad.detach())
        grad_right = helper.inf_to_nan(param_right.grad.detach())
        if self.maximize:
            grad_left = grad_left.neg()
            grad_right = grad_right.neg()

        grad_left, _ = helper.move_lora_dim_to_last(grad_left, group['lora_l_dim'])
        grad_right, _ = helper.move_lora_dim_to_last(grad_right, group['lora_r_dim'])

        update_left = helper.get_unmagnified_grad(grad_left, rotate_inv_right)
        update_right = helper.get_unmagnified_grad(grad_right, rotate_inv_left)

        if group['update_skipping'] > 0.0:
            update_left = helper.skip_update(update_left, group['update_skipping'])
            update_right = helper.skip_update(update_right, group['update_skipping'])

        state['basis_l'] = basis_left
        state['basis_r'] = basis_right
        state['rotate_inv_l'] = rotate_inv_left
        state['rotate_inv_r'] = rotate_inv_right
        state['update_l'] = update_left
        state['update_r'] = update_right
        state['projection_l'] = projection_left
        state['projection_r'] = projection_right

        return state

    def apply_pair_update(
        self,
        group: ParamGroup,
        param_left: torch.Tensor,
        param_right: torch.Tensor,
        grad_norm: torch.Tensor,
    ) -> None:
        r"""Second pass for one pair: clip, precondition, apply momentum, and write the parameter update.

        Args:
            group (ParamGroup): parameter group the pair belongs to.
            param_left (torch.Tensor): left LoRA factor; its state holds the intermediates from ``build_pair_info``.
            param_right (torch.Tensor): right LoRA factor.
            grad_norm (torch.Tensor): global norm of all unmagnified updates, used for ``clip_unmagnified_grad``.

        """
        helper = self.helper
        state = self.state[param_left]
        # pop the first-pass intermediates so they are not kept in the persistent state
        update_left = state.pop('update_l')
        update_right = state.pop('update_r')
        rotate_inv_left = state.pop('rotate_inv_l')
        rotate_inv_right = state.pop('rotate_inv_r')
        projection_left = state.pop('projection_l')
        projection_right = state.pop('projection_r')
        beta1, beta2 = group['betas']

        param_left_2d, _ = helper.move_lora_dim_to_last(param_left, group['lora_l_dim'])
        param_right_2d, _ = helper.move_lora_dim_to_last(param_right, group['lora_r_dim'])

        # global-norm clipping across every pair of every group (norm computed in ``step``)
        if group['clip_unmagnified_grad'] > 0.0 and grad_norm > group['clip_unmagnified_grad']:
            scale = group['clip_unmagnified_grad'] / grad_norm
            update_left = update_left * scale
            update_right = update_right * scale

        second_left = helper.compute_second_moment(update_left)
        second_right = helper.compute_second_moment(update_right)

        # carry the previous second-moment statistics over into the current bases
        transformed_v_left = helper.transform_second_moment_to_new_basis(state['v_l'], projection_left)
        transformed_v_right = helper.transform_second_moment_to_new_basis(state['v_r'], projection_right)

        if group['apply_escape']:
            escape_left = helper.get_unmagnified_rotate_second_escape(transformed_v_left, state['v_l'])
            escape_right = helper.get_unmagnified_rotate_second_escape(transformed_v_right, state['v_r'])
            escape_left = helper.update_second_escape(
                state['step'], torch.zeros_like(escape_left), escape_left + state['escape_l'], beta2
            )
            escape_right = helper.update_second_escape(
                state['step'], torch.zeros_like(escape_right), escape_right + state['escape_r'], beta2
            )
        else:
            escape_left = torch.zeros((), dtype=param_left_2d.dtype, device=param_left_2d.device)
            escape_right = torch.zeros((), dtype=param_right_2d.dtype, device=param_right_2d.device)

        v_left = helper.update_second_moment(state['step'], second_left, transformed_v_left, beta2)
        v_right = helper.update_second_moment(state['step'], second_right, transformed_v_right, beta2)

        update_left = helper.get_preconditioned_update(
            update_left,
            v_left,
            escape_left,
            group['eps'],
            group['eps_root'],
            group['relative_epsilon'],
            group['apply_escape'],
        )
        update_right = helper.get_preconditioned_update(
            update_right,
            v_right,
            escape_right,
            group['eps'],
            group['eps_root'],
            group['relative_epsilon'],
            group['apply_escape'],
        )

        # first moment is also re-expressed in the new bases before being updated with the preconditioned step
        m_left = helper.transform_first_moment_to_new_basis(state['m_l'], projection_left)
        m_right = helper.transform_first_moment_to_new_basis(state['m_r'], projection_right)
        m_left = helper.update_first_moment(state['step'], update_left, m_left, beta1)
        m_right = helper.update_first_moment(state['step'], update_right, m_right, beta1)

        if group['update_capping'] > 0.0:
            m_left = helper.clip_update(m_left, group['update_capping'])
            m_right = helper.clip_update(m_right, group['update_capping'])

        update_left = helper.rotate_update(m_left, rotate_inv_right)
        update_right = helper.rotate_update(m_right, rotate_inv_left)

        # decoupled weight decay: added after preconditioning/momentum, then scaled by lr together with the update
        if group['weight_decay'] > 0.0:
            update_left = update_left.add(param_left_2d, alpha=group['weight_decay'])
            update_right = update_right.add(param_right_2d, alpha=group['weight_decay'])

        update_left = update_left.mul(-group['lr'])
        update_right = update_right.mul(-group['lr'])

        if group['balance_param']:
            # Rescale each updated factor (param + update) to the geometric mean of the two updated norms, then
            # express the result as a delta relative to the current param.
            left_norm = torch.linalg.norm(param_left_2d + update_left).add_(1e-6)
            right_norm = torch.linalg.norm(param_right_2d + update_right).add_(1e-6)
            balanced_norm = torch.sqrt(left_norm * right_norm)
            update_left = update_left * (balanced_norm / left_norm) + param_left_2d * (balanced_norm / left_norm - 1.0)
            update_right = update_right * (balanced_norm / right_norm) + param_right_2d * (
                balanced_norm / right_norm - 1.0
            )

        param_left.add_(helper.restore_param_shape(update_left, param_left, group['lora_l_dim']).to(param_left.dtype))
        param_right.add_(
            helper.restore_param_shape(update_right, param_right, group['lora_r_dim']).to(param_right.dtype)
        )

        state['step'] += 1
        state['v_l'] = v_left
        state['v_r'] = v_right
        state['m_l'] = m_left
        state['m_r'] = m_right
        state['escape_l'] = escape_left
        state['escape_r'] = escape_right

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform one optimization step over all complete LoRA pairs that have gradients on both factors."""
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        pair_infos: List[Tuple[ParamGroup, torch.Tensor, torch.Tensor]] = []
        grad_norm_sq: Optional[torch.Tensor] = None

        # pass 1: build per-pair intermediates and accumulate the global squared norm of unmagnified updates
        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for param_left, param_right in self.iter_lora_pairs(group):
                if param_left.grad is None or param_right.grad is None:
                    continue

                state = self.build_pair_info(group, param_left, param_right)
                update_left, update_right = state['update_l'], state['update_r']
                if grad_norm_sq is None:
                    grad_norm_sq = update_left.new_zeros(())

                grad_norm_sq.add_(torch.linalg.norm(update_left).pow(2).to(grad_norm_sq.device))
                grad_norm_sq.add_(torch.linalg.norm(update_right).pow(2).to(grad_norm_sq.device))
                pair_infos.append((group, param_left, param_right))

        # pass 2: apply the updates with the global norm (the zero fallback is only hit when no pair was collected)
        grad_norm = torch.sqrt(grad_norm_sq) if grad_norm_sq is not None else torch.zeros(())
        for group, param_left, param_right in pair_infos:
            self.apply_pair_update(group, param_left, param_right, grad_norm)

        return loss

MADGRAD

Bases: BaseOptimizer

A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified).

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
momentum float

Momentum factor.

0.9
eps float

Term added to the denominator to improve numerical stability.

1e-06
weight_decay float

Weight decay (L2 penalty). MADGRAD optimizer requires less weight decay than other methods, often as little as zero. On sparse problems both weight_decay and momentum should be set to 0.

0.0
weight_decouple bool

Apply AdamW style decoupled weight decay.

False
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/madgrad.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class MADGRAD(BaseOptimizer):
    """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization (slightly modified).

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float): Momentum factor, validated to lie in the 0..1 range. When > 0 the parameters follow a
            moving average of the dual-averaged iterate. Sparse gradients require momentum to be 0.
        weight_decay (float): Weight decay (L2 penalty).
            MADGRAD optimizer requires less weight decay than other methods, often as little as zero.
            On sparse problems both weight_decay and momentum should be set to 0.
        weight_decouple (bool): Apply AdamW style decoupled weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'momentum': momentum,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MADGRAD'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily allocate per-parameter state.

        Called on every step, so a parameter that receives its first gradient after step 0 (or belongs to a group
        added later) still gets its state. State that already exists is never overwritten.

        Raises:
            NoSparseGradientError: sparse gradient combined with momentum or coupled weight decay.
            NoComplexParameterError: complex-valued parameter.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if group['momentum'] > 0.0 and grad.is_sparse:
                raise NoSparseGradientError(str(self), note='momentum > 0.0')

            if group['weight_decay'] > 0.0 and not group['weight_decouple'] and grad.is_sparse:
                raise NoSparseGradientError(str(self), note='weight_decay')

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]
            if len(state) > 0:
                continue

            state['grad_sum_sq'] = torch.zeros_like(p)
            state['s'] = torch.zeros_like(p)

            # The anchor point x0 is only stored when momentum is used. Without momentum, step() recomputes it
            # from p, s and grad_sum_sq.
            if group['momentum'] > 0.0:
                state['x0'] = p.clone()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Global iteration counter shared by every parameter group.
        if 'k' not in self.state:
            self.state['k'] = torch.tensor([0], dtype=torch.long, requires_grad=False)

        for group in self.param_groups:
            self.init_group(group)

            weight_decay, momentum, eps = group['weight_decay'], group['momentum'], group['eps']
            lr: float = group['lr'] + eps

            # Dual-averaging weight for this iteration: lr * sqrt(k + 1).
            _lambda = lr * math.pow(self.state['k'] + 1, 0.5)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_sum_sq, s = state['grad_sum_sq'], state['s']
                if weight_decay > 0.0 and not group['weight_decouple']:
                    grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Recover x0 on the touched entries from the current iterate.
                    rms_masked_values = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_values = p_masked._values().addcdiv(s_masked._values(), rms_masked_values, value=1)

                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=_lambda)
                    grad_sum_sq_masked.add_(grad_sq, alpha=_lambda)

                    rms_masked_values = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
                    if eps == 0.0:
                        rms_masked_values[rms_masked_values == 0] = float('inf')

                    s.add_(grad, alpha=_lambda)
                    s_masked._values().add_(grad._values(), alpha=_lambda)

                    p_kp1_masked_values = x0_masked_values.addcdiv(s_masked._values(), rms_masked_values, value=-1)

                    # p_masked <- p - p_{k+1} on the touched entries, then p <- p - (p - p_{k+1}) = p_{k+1}.
                    p_masked._values().add_(p_kp1_masked_values, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0.0:
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    grad_sum_sq.addcmul_(grad, grad, value=_lambda)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    if eps == 0.0:
                        rms[rms == 0] = float('inf')

                    s.add_(grad, alpha=_lambda)

                    p_old: Optional[torch.Tensor] = None
                    if weight_decay > 0.0 and group['weight_decouple']:
                        p_old = p.clone()

                    if momentum == 0.0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        # p becomes a moving average of the dual-averaged point z.
                        z = x0.addcdiv(s, rms, value=-1)
                        p.mul_(momentum).add_(z, alpha=1.0 - momentum)

                    if weight_decay > 0.0 and group['weight_decouple']:
                        p.add_(p_old, alpha=-lr * weight_decay)

        self.state['k'].add_(1)

        return loss

MARS

Bases: BaseOptimizer

Unleashing the Power of Variance Reduction for Training Large Models.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.003
betas Betas

Coefficients used for computing running averages of the corrected gradient and its square.

(0.95, 0.99)
gamma float

The scaling parameter that controls the strength of gradient correction.

0.025
mars_type MARS_TYPE

Type of MARS. Supported types are adamw, lion, shampoo.

'adamw'
optimize_1d bool

Whether MARS should optimize 1D parameters.

False
lr_1d float

Learning rate for AdamW when optimize_1d is set to False.

0.003
betas_1d Betas

Coefficients for running averages of gradient and squared gradient for 1D parameters.

(0.9, 0.95)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decay_1d float

Weight decay for 1D parameters.

0.1
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
ams_bound bool

Whether to use the AMSBound variant.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/mars.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
class MARS(BaseOptimizer):
    """Unleashing the Power of Variance Reduction for Training Large Models.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the corrected gradient and its square.
        gamma (float): The scaling parameter that controls the strength of gradient correction.
        mars_type (MARS_TYPE): Type of MARS. Supported types are `adamw`, `lion`, `shampoo`.
        optimize_1d (bool): Whether MARS should optimize 1D parameters.
        lr_1d (float): Learning rate for AdamW when optimize_1d is set to False.
        betas_1d (Betas): Coefficients for running averages of gradient and squared gradient for 1D parameters.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decay_1d (float): Weight decay for 1D parameters handled by the AdamW path (optimize_1d=False).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        ams_bound (bool): Whether to use the AMSBound variant.
        cautious (bool): Whether to apply the cautious update mask.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-3,
        betas: Betas = (0.95, 0.99),
        gamma: float = 0.025,
        mars_type: MARS_TYPE = 'adamw',
        optimize_1d: bool = False,
        lr_1d: float = 3e-3,
        betas_1d: Betas = (0.9, 0.95),
        weight_decay: float = 0.0,
        weight_decay_1d: float = 1e-1,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        cautious: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(lr_1d)
        self.validate_betas(betas)
        self.validate_betas(betas_1d)
        self.validate_options(mars_type, 'mars_type', ['adamw', 'lion', 'shampoo'])
        self.validate_non_negative(gamma, 'gamma')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(weight_decay_1d, 'weight_decay_1d')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'lr_1d': lr_1d,
            # The 1D step size is derived from the current ``lr`` through this ratio, so rescaling ``lr`` also
            # rescales it.
            'lr_1d_factor': lr_1d / lr,
            'betas': betas,
            'betas_1d': betas_1d,
            'mars_type': mars_type,
            'gamma': gamma,
            'optimize_1d': optimize_1d,
            'weight_decay': weight_decay,
            'weight_decay_1d': weight_decay_1d,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'cautious': cautious,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MARS'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['last_grad'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    def optimize_mixed(
        self,
        grad: torch.Tensor,
        last_grad: torch.Tensor,
        exp_avg: torch.Tensor,
        exp_avg_sq: torch.Tensor,
        max_exp_avg_sq: Optional[torch.Tensor],
        betas: Tuple[float, float],
        gamma: float,
        mars_type: MARS_TYPE,
        is_grad_2d: bool,
        step: int,
        ams_bound: bool,
        cautious: bool,
        eps: float,
    ) -> torch.Tensor:
        """Compute the MARS update from the variance-reduced gradient ``c_t``.

        Updates ``exp_avg`` (and ``exp_avg_sq`` on the AdamW path) in place and returns a new update tensor.
        """
        beta1, beta2 = betas

        # Variance-reduced gradient, rescaled to unit norm whenever its norm exceeds 1.
        c_t = (grad - last_grad).mul_(gamma * (beta1 / (1.0 - beta1))).add_(grad)
        c_t_norm = torch.norm(c_t)
        if c_t_norm > 1.0:
            c_t.div_(c_t_norm)

        exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)

        update = exp_avg.clone()
        if cautious:
            self.apply_cautious(update, grad)

        if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d):
            exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)

            bias_correction1: float = self.debias(beta1, step)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, step))

            de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
            de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

            return update.div_(de_nom)

        if mars_type == 'lion':
            return update.sign_()

        factor: float = math.sqrt(max(1.0, grad.size(0) / grad.size(1)))

        update = update.view(update.size(0), -1)

        return zero_power_via_newton_schulz_5(update.mul_(1.0 / (1.0 - beta1)), eps=eps).mul_(factor).view_as(grad)

    def optimize_1d(
        self,
        grad: torch.Tensor,
        exp_avg: torch.Tensor,
        exp_avg_sq: torch.Tensor,
        max_exp_avg_sq: Optional[torch.Tensor],
        betas: Tuple[float, float],
        step: int,
        ams_bound: bool,
        cautious: bool,
        eps: float,
    ) -> torch.Tensor:
        """Compute a plain AdamW-style update (no gradient correction) for 1D parameters."""
        beta1, beta2 = betas

        bias_correction1: float = self.debias(beta1, step)
        bias_correction2_sq: float = math.sqrt(self.debias(beta2, step))

        exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

        update = exp_avg.clone()

        if cautious:
            self.apply_cautious(update, grad)

        de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
        de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

        return update.div_(de_nom)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, last_grad = state['exp_avg'], state['exp_avg_sq'], state['last_grad']

                p, grad, exp_avg, exp_avg_sq, last_grad = self.view_as_real(p, grad, exp_avg, exp_avg_sq, last_grad)

                is_grad_2d: bool = grad.ndim >= 2
                use_mars: bool = group['optimize_1d'] or is_grad_2d

                step_size: float = group['lr'] if use_mars else group['lr'] * group['lr_1d_factor']
                # Groups restored from older state dicts may lack 'weight_decay_1d'; fall back to 'weight_decay'.
                weight_decay: float = (
                    group['weight_decay'] if use_mars else group.get('weight_decay_1d', group['weight_decay'])
                )

                # Decay must happen before the update is built. Coupled (L2) decay adds into ``grad`` in place and
                # would otherwise not affect this step. Decoupled decay only rescales ``p``, which the update
                # computation never reads.
                self.apply_weight_decay(
                    p,
                    grad,
                    lr=step_size,
                    weight_decay=weight_decay,
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if use_mars:
                    update = self.optimize_mixed(
                        grad,
                        last_grad,
                        exp_avg,
                        exp_avg_sq,
                        state.get('max_exp_avg_sq', None),
                        group['betas'],
                        group['gamma'],
                        group['mars_type'],
                        is_grad_2d,
                        group['step'],
                        group['ams_bound'],
                        group.get('cautious'),
                        group['eps'],
                    )
                else:
                    update = self.optimize_1d(
                        grad,
                        exp_avg,
                        exp_avg_sq,
                        state.get('max_exp_avg_sq', None),
                        group['betas_1d'],
                        group['step'],
                        group['ams_bound'],
                        group.get('cautious'),
                        group['eps'],
                    )

                p.add_(update, alpha=-step_size)

                state['last_grad'] = torch.view_as_complex(grad) if torch.is_complex(state['last_grad']) else grad

        return loss

MSVAG

Bases: BaseOptimizer

Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
beta float

Moving average (momentum) constant (scalar tensor or float value).

0.9
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/msvag.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class MSVAG(BaseOptimizer):
    """Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        beta (float): Moving average (momentum) constant (scalar tensor or float value).
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        beta: float = 0.9,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        self.maximize = maximize

        defaults: Defaults = {'lr': lr, 'beta': beta}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MSVAG'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)

    @staticmethod
    def get_rho(beta_power: float, beta: float) -> float:
        r"""Get rho, the relative variance of a bias-corrected exponential moving average.

        With ``beta_power`` = :math:`\beta^t`:

        .. math::
            \rho = \frac{(1 - \beta^{2t})(1 - \beta)^2}{(1 - \beta^2)(1 - \beta^t)^2}
                 = \frac{(1 - \beta)(1 + \beta^t)}{(1 + \beta)(1 - \beta^t)}

        The result is clamped to at most 0.9999 so that ``1 - rho`` stays positive (rho equals 1 at t = 1).
        """
        rho: float = (1.0 - beta_power ** 2) * (1.0 - beta) ** 2  # fmt: skip
        rho /= (1.0 - beta ** 2) * (1.0 - beta_power) ** 2  # fmt: skip
        return min(rho, 0.9999)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta: float = group['beta']
            beta_power: float = beta ** group['step']

            # The moving averages start at zero, so they are debiased by (1 - beta^t), as in Adam.
            bias_correction: float = self.debias(beta, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                exp_avg.mul_(beta).add_(grad, alpha=1.0 - beta)
                exp_avg_sq.mul_(beta).addcmul_(grad, grad, value=1.0 - beta)

                # Bias-corrected first and second moment estimates.
                m = exp_avg.div(bias_correction)
                v = exp_avg_sq.div(bias_correction)

                rho: float = self.get_rho(beta_power, beta)

                # Variance estimate and the variance-adaptation factor m^2 / (m^2 + rho * s), clamped to [0, 1].
                m_p2 = m.pow(2)
                s = (v - m_p2).div_(1.0 - rho)

                factor = m_p2.div(m_p2 + rho * s)
                torch.nan_to_num(factor, nan=0.0, out=factor)
                factor.clamp_(0.0, 1.0)

                p.add_(m * factor, alpha=-group['lr'])

        return loss

get_rho(beta_power, beta) staticmethod

Get rho, the variance-adjustment factor used by MSVAG, clamped to at most 0.9999.

Source code in pytorch_optimizer/optimizer/msvag.py
58
59
60
61
62
63
@staticmethod
def get_rho(beta_power: float, beta: float) -> float:
    r"""Get rho, the relative variance of a bias-corrected exponential moving average.

    With ``beta_power`` = :math:`\beta^t`:

    .. math::
        \rho = \frac{(1 - \beta^{2t})(1 - \beta)^2}{(1 - \beta^2)(1 - \beta^t)^2}
             = \frac{(1 - \beta)(1 + \beta^t)}{(1 + \beta)(1 - \beta^t)}

    The result is clamped to at most 0.9999 so that ``1 - rho`` stays positive (rho equals 1 at t = 1).
    """
    rho: float = (1.0 - beta_power ** 2) * (1.0 - beta) ** 2  # fmt: skip
    rho /= (1.0 - beta ** 2) * (1.0 - beta_power) ** 2  # fmt: skip
    return min(rho, 0.9999)

Muon

Bases: BaseOptimizer

Momentum Orthogonalized by Newton-schulz.

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and scalar or vector parameters should be optimized using AdamW.

Some warnings:

- We believe this optimizer is unlikely to work well for training with small batch size.
- We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

Parameters:

Name Type Description Default
params ParamsT

The parameters to be optimized by Muon.

required
lr float

Learning rate.

0.02
momentum float

The momentum used by the internal SGD.

0.95
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
nesterov bool

Whether to use nesterov momentum.

True
ns_steps int

The number of Newton-Schulz iterations to run. (5 is probably always enough)

5
ns_coeffs NewtonSchulzWeights

Newton-Schulz coefficients or preset name.

'original'
use_adjusted_lr bool

Whether to use adjusted learning rate, which is from the Moonlight. Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py

False
adamw_lr float

The learning rate for the internal AdamW.

0.0003
adamw_betas tuple

The betas for the internal AdamW.

(0.9, 0.95)
adamw_wd float

The weight decay for the internal AdamW.

0.0
adamw_eps float

The epsilon for the internal AdamW.

1e-10
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Example

from pytorch_optimizer import Muon

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

param_groups = [
    dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
    dict(
        params=hidden_gains_biases + non_hidden_params,
        lr=3e-4,
        betas=(0.9, 0.95),
        weight_decay=0.01,
        use_muon=False,
    ),
]

optimizer = Muon(param_groups)

Source code in pytorch_optimizer/optimizer/muon.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class Muon(BaseOptimizer):
    """Momentum Orthogonalized by Newton-schulz.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
    each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
    update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

    Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and
    scalar or vector parameters should be optimized using AdamW.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

    Args:
        params (ParamsT): The parameter groups to be optimized. Every group must be a dict that sets ``use_muon``.
            Any iterable of such dicts is accepted, including a one-shot generator.
        lr (float): Learning rate.
        momentum (float): The momentum used by the internal SGD.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        nesterov (bool): Whether to use nesterov momentum.
        ns_steps (int): The number of Newton-Schulz iterations to run. (5 is probably always enough)
        ns_coeffs (NewtonSchulzWeights): Newton-Schulz coefficients or preset name.
        use_adjusted_lr (bool): Whether to use adjusted learning rate, which is from the Moonlight.
            Reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
        adamw_lr (float): The learning rate for the internal AdamW.
        adamw_betas (Betas): The betas for the internal AdamW.
        adamw_wd (float): The weight decay for the internal AdamW.
        adamw_eps (float): The epsilon for the internal AdamW.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    Raises:
        ValueError: if a parameter group is not a dict or does not set ``use_muon``.

    Example:
        from pytorch_optimizer import Muon

        hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]
        hidden_gains_biases = [p for p in model.body.parameters() if p.ndim < 2]
        non_hidden_params = [*model.head.parameters(), *model.embed.parameters()]

        param_groups = [
            dict(params=hidden_weights, lr=0.02, weight_decay=0.01, use_muon=True),
            dict(
                params=hidden_gains_biases + non_hidden_params,
                lr=3e-4,
                betas=(0.9, 0.95),
                weight_decay=0.01,
                use_muon=False,
            ),
        ]

        optimizer = Muon(param_groups)

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 2e-2,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        nesterov: bool = True,
        ns_steps: int = 5,
        ns_coeffs: NewtonSchulzWeights = 'original',
        use_adjusted_lr: bool = False,
        adamw_lr: float = 3e-4,
        adamw_betas: Betas = (0.9, 0.95),
        adamw_wd: float = 0.0,
        adamw_eps: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(adamw_lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_positive(ns_steps, 'ns_steps')
        self.validate_betas(adamw_betas)
        self.validate_non_negative(adamw_wd, 'adamw_wd')
        self.validate_non_negative(adamw_eps, 'adamw_eps')
        ns_coeffs = get_newton_schulz_weights(ns_coeffs)

        self.maximize = maximize

        # Materialize once. ``params`` may be a one-shot iterator (e.g. a generator); filling in per-group defaults
        # below would otherwise exhaust it before ``super().__init__`` sees any parameter group.
        params = list(params)

        for group in params:
            if not isinstance(group, dict) or 'use_muon' not in group:
                raise ValueError('`use_muon` must be set.')

            group = cast(ParamGroup, group)

            if group['use_muon']:
                group['lr'] = group.get('lr', lr)
                group['momentum'] = group.get('momentum', momentum)
                group['nesterov'] = group.get('nesterov', nesterov)
                group['weight_decay'] = group.get('weight_decay', weight_decay)
                group['ns_steps'] = group.get('ns_steps', ns_steps)
                group['ns_coeffs'] = get_newton_schulz_weights(group.get('ns_coeffs', ns_coeffs))
                group['use_adjusted_lr'] = group.get('use_adjusted_lr', use_adjusted_lr)
            else:
                group['lr'] = group.get('lr', adamw_lr)
                group['betas'] = group.get('betas', adamw_betas)
                group['eps'] = group.get('eps', adamw_eps)
                group['weight_decay'] = group.get('weight_decay', adamw_wd)

            group['weight_decouple'] = group.get('weight_decouple', weight_decouple)

        super().__init__(params, kwargs)

    def __str__(self) -> str:
        return 'Muon'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if group['use_muon']:
                    state['momentum_buffer'] = torch.zeros_like(p)
                else:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                if group['use_muon']:
                    buf = state['momentum_buffer']
                    buf.lerp_(grad, weight=1.0 - group['momentum'])

                    # Nesterov: blend the gradient toward the momentum buffer (overwrites ``grad`` in place).
                    update = grad.lerp_(buf, weight=group['momentum']) if group['nesterov'] else buf
                    if update.ndim > 2:
                        # Flatten trailing dims so that conv-like weights are orthogonalized as matrices.
                        update = update.view(len(update), -1)

                    update = zero_power_via_newton_schulz_5(
                        update, num_steps=group['ns_steps'], weights=group['ns_coeffs']
                    )

                    if group.get('cautious'):
                        self.apply_cautious(update, grad)

                    lr: float = get_adjusted_lr(group['lr'], p.size(), use_adjusted_lr=group['use_adjusted_lr'])

                    p.add_(update.reshape(p.shape), alpha=-lr)
                else:
                    exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                    beta1, beta2 = group['betas']

                    bias_correction1: float = self.debias(beta1, group['step'])
                    bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

                    exp_avg.lerp_(grad, weight=1.0 - beta1)
                    exp_avg_sq.lerp_(grad.square(), weight=1.0 - beta2)

                    de_nom = exp_avg_sq.sqrt().add_(group['eps']).div_(bias_correction2_sq)

                    p.addcdiv_(exp_avg / bias_correction1, de_nom, value=-group['lr'])

        return loss

Nero

Bases: BaseOptimizer

Learning by Turning: Neural Architecture Aware Optimisation.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
beta float

Coefficients used for computing running averages of gradient and the squared Hessian trace.

0.999
constraints bool

Boolean flag indicating usage of constraints.

True
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/nero.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class Nero(BaseOptimizer):
    """Learning by Turning: Neural Architecture Aware Optimisation.

    Each parameter keeps a running average of ``neuron_norm(grad) ** 2``. The gradient is divided by the
    bias-corrected square root of that average (NaN entries are zeroed). The step is ``lr * scale``, where
    ``scale`` is the mean ``neuron_norm`` of the parameter when its state was created (``0.01`` if that mean is
    zero). With ``constraints`` enabled, every parameter with more than one dimension is handled the same way at
    state creation and after every update: it is re-centred by ``neuron_mean`` and divided by ``neuron_norm + eps``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        beta (float): Coefficient of the running average of the squared neuron-wise gradient norm.
        constraints (bool): Boolean flag indicating usage of constraints.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 0.01,
        beta: float = 0.999,
        constraints: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # Extra ``kwargs`` are accepted for signature compatibility but are not stored in the defaults.
        super().__init__(params, {'lr': lr, 'beta': beta, 'constraints': constraints, 'eps': eps})

    def __str__(self) -> str:
        return 'Nero'

    @staticmethod
    def _apply_constraints(param, eps: float) -> None:
        """Re-centre ``param`` by ``neuron_mean`` and divide it by ``neuron_norm + eps``, in place."""
        param.sub_(neuron_mean(param))
        param.div_(neuron_norm(param).add_(eps))

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))
            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            state = self.state[param]
            if state:
                # State was already created on an earlier step.
                continue

            if group['constraints'] and param.dim() > 1:
                self._apply_constraints(param, group['eps'])

            norms = neuron_norm(param)
            state['exp_avg_sq'] = torch.zeros_like(norms)

            mean_norm = norms.mean()
            # Fall back to a fixed scale so an all-zero parameter still moves.
            state['scale'] = 0.01 if mean_norm == 0.0 else mean_norm

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta, eps = group['beta'], group['eps']
            debias_factor: float = self.debias(beta, group['step'])

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]
                avg_sq_norm = state['exp_avg_sq']

                grad_norm = neuron_norm(grad)
                avg_sq_norm.mul_(beta).addcmul_(grad_norm, grad_norm, value=1.0 - beta)

                # Normalise by the bias-corrected RMS norm; NaN entries (e.g. 0 / 0) become 0.
                rms = avg_sq_norm.div(debias_factor).sqrt_().add_(eps)
                update = torch.nan_to_num(grad / rms, nan=0.0)

                param.add_(update, alpha=-group['lr'] * state['scale'])

                if group['constraints'] and param.dim() > 1:
                    self._apply_constraints(param, eps)

        return loss

NovoGrad

Bases: BaseOptimizer

Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients for the gradient momentum (beta1) and for the running average of the squared gradient norm of each parameter tensor (beta2).

(0.95, 0.98)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
grad_averaging bool

Multiply the normalized gradient by (1 - beta1) before adding it to the momentum.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/novograd.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class NovoGrad(BaseOptimizer):
    """Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.

    For every parameter tensor NovoGrad keeps a scalar running average of the squared gradient norm
    (``grads_ema``). It normalizes the whole gradient by the square root of that average before accumulating it
    into the momentum buffer (``moments``). Note that ``step`` rescales ``p.grad`` in place.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients for the momentum (``beta1``) and for the running average of the squared
            gradient norm (``beta2``).
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        grad_averaging (bool): Multiply the normalized gradient by ``1 - beta1`` before adding it to the momentum.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.
        kwargs: Extra options stored in the param-group defaults (``step`` reads ``adam_debias``).

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.95, 0.98),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        grad_averaging: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'grad_averaging': grad_averaging,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'NovoGrad'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # ``step`` negates the gradient in place *after* this method runs when maximizing, so the
                # initial momentum must use the same sign; otherwise it would point against every later update.
                if self.maximize:
                    grad = -grad

                # Only needed when the state is first created; computing it on every step was wasted work.
                grad_p2 = grad.pow(2).sum()

                state['moments'] = grad.div(grad_p2.sqrt().add_(group['eps'])) + group['weight_decay'] * p
                state['grads_ema'] = grad_p2

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # How ``bias_correction1`` enters the step size depends on the ``adam_debias`` group option.
            step_size: float = self.apply_adam_debias(
                group.get('adam_debias', False),
                step_size=group['lr'] * bias_correction2_sq,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grads_ema, moments = state['grads_ema'], state['moments']

                # One scalar second-moment estimate per parameter tensor (layer-wise adaptivity).
                grads_ema.mul_(beta2).add_(grad.pow(2).sum(), alpha=1.0 - beta2)

                de_nom = grads_ema.sqrt().add_(group['eps'])
                grad.div_(de_nom)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['grad_averaging']:
                    grad.mul_(1.0 - beta1)

                moments.mul_(beta1).add_(grad)

                p.add_(moments, alpha=-step_size)

        return loss

OrthoGrad

Bases: BaseOptimizer

Grokking at the Edge of Numerical Stability.

A wrapper optimizer that projects gradients to be orthogonal to the current parameters before performing an update.

Parameters:

Name Type Description Default
optimizer OptimizerInstanceOrClass

Base optimizer.

required
Source code in pytorch_optimizer/optimizer/orthograd.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class OrthoGrad(BaseOptimizer):
    """Grokking at the Edge of Numerical Stability.

    A wrapper optimizer that projects gradients to be orthogonal to the current parameters before performing an update.
    Each orthogonalized gradient is rescaled to the L2 norm of the original gradient. ``param_groups``, ``state``,
    ``state_dict``/``load_state_dict`` and ``zero_grad`` are all delegated to the wrapped optimizer.

    Args:
        optimizer (OptimizerInstanceOrClass): Base optimizer.
        kwargs: Forwarded to ``load_optimizer`` together with ``optimizer``. NOTE(review): presumably used to
            construct the optimizer when a class is given; confirm against ``load_optimizer``.

    """

    def __init__(self, optimizer: OptimizerInstanceOrClass, **kwargs) -> None:
        # ``super().__init__`` is never called, so these hook registries are created by hand.
        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}
        # Guards the two norm divisions in ``apply_orthogonal_gradients`` against division by zero.
        self.eps: float = 1e-30

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self.defaults: Defaults = self.optimizer.defaults

    def __str__(self) -> str:
        return 'OrthoGrad'

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @property
    def state(self) -> State:
        return self.optimizer.state

    def state_dict(self) -> State:
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: State) -> None:
        self.optimizer.load_state_dict(state_dict)

    @torch.no_grad()
    def zero_grad(self, set_to_none: bool = True) -> None:
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

    @torch.no_grad()
    def apply_orthogonal_gradients(self, params) -> None:
        """Replace each dense, real gradient with its component orthogonal to the parameter.

        The orthogonal component is computed in float32 and rescaled to the L2 norm of the original gradient. It is
        then copied back into ``p.grad``, keeping that tensor's dtype. Parameters without a gradient, with a sparse
        gradient, or with complex values are skipped.

        Args:
            params: Iterable of parameters whose gradients are modified in place.
        """
        for p in params:
            if p.grad is None or p.grad.is_sparse or torch.is_complex(p):
                continue

            # ``reshape`` (unlike ``view``) also handles non-contiguous tensors such as channels_last weights. It
            # flattens in logical order, which matches the ``view_as`` + ``copy_`` write-back below.
            w = p.reshape(-1)
            g = p.grad.reshape(-1)

            proj = torch.dot(w, g).div_(torch.dot(w, w).add_(self.eps))
            g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(self.eps)))

            p.grad.copy_(g_ortho_scaled.view_as(p.grad))

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Orthogonalize the current gradients of every param group, then delegate to the wrapped optimizer.

        NOTE(review): the gradients are orthogonalized *before* ``closure`` is handed to the wrapped optimizer.
        If that optimizer evaluates the closure and the closure recomputes gradients, the orthogonalization is
        overwritten.
        """
        for group in self.param_groups:
            self.apply_orthogonal_gradients(group['params'])
        return self.optimizer.step(closure)

PAdam

Bases: BaseOptimizer

Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.1
betas Betas

Coefficients used for computing running averages of the gradient and of its square.

(0.9, 0.999)
partial float

Partially adaptive parameter.

0.25
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/padam.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
class PAdam(BaseOptimizer):
    """Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks.

    Partially adaptive Adam: the first moment is divided by ``(sqrt(exp_avg_sq) + eps) ** (2 * partial)``.
    With ``partial = 0.5`` the update is Adam-like; smaller values move it towards SGD with momentum.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of its square.
        partial (float): Partially adaptive parameter, in ``(0, 1]``.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-1,
        betas: Betas = (0.9, 0.999),
        partial: float = 0.25,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(partial, 'partial', 0.0, 1.0, range_type='(]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'partial': partial,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PAdam'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # Honor ``maximize``: it was stored in ``__init__`` but previously never applied.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Complex parameters are handled through their real views.
                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                # Partial adaptivity: raise the RMS denominator to the power ``2 * partial``.
                p.addcdiv_(exp_avg, de_nom ** (group['partial'] * 2), value=-step_size)

        return loss

PCGrad

Bases: BaseOptimizer

Gradient Surgery for Multi-Task Learning.

Parameters:

Name Type Description Default
optimizer Optimizer

Optimizer instance.

required
reduction str

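How gradient entries shared by all objectives are merged ('mean' or 'sum'); entries not shared by every objective are always summed.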

'mean'
Source code in pytorch_optimizer/optimizer/pcgrad.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class PCGrad(BaseOptimizer):
    """Gradient Surgery for Multi-Task Learning.

    Wraps an existing optimizer. ``pc_backward`` back-propagates every objective separately. It then projects away
    the pairwise-conflicting gradient components and writes the merged result into ``p.grad``. ``step`` simply
    calls the wrapped optimizer.

    Args:
        optimizer (Optimizer): Optimizer instance.
        reduction (str): How gradient entries shared by all objectives are merged, ``'mean'`` or ``'sum'``.
            Entries not shared by every objective are always summed.

    """

    def __init__(self, optimizer: Optimizer, reduction: str = 'mean'):
        self.validate_options(reduction, 'reduction', ['mean', 'sum'])

        self.optimizer = optimizer
        self.reduction = reduction

    @torch.no_grad()
    def init_group(self):
        """Clear the wrapped optimizer's gradients."""
        self.zero_grad()

    def zero_grad(self):
        """Set the gradients of all wrapped parameters to ``None``."""
        return self.optimizer.zero_grad(set_to_none=True)

    def step(self):
        """Run one step of the wrapped optimizer (no closure is passed)."""
        return self.optimizer.step()

    def set_grad(self, grads: List[torch.Tensor]) -> None:
        """Assign ``grads`` to the parameters in ``param_groups`` order.

        Args:
            grads (List[torch.Tensor]): One tensor per parameter, in the same order ``retrieve_grad`` visits them.
        """
        idx: int = 0
        for group in self.optimizer.param_groups:
            for p in group['params']:
                p.grad = grads[idx]
                idx += 1

    def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[torch.Size], List[torch.Tensor]]:
        """Get the gradient of the parameters of the network with specific objective.

        Returns:
            ``(grads, shapes, has_grads)``, with one entry per parameter in ``param_groups`` order. Each entry is:
            a clone of ``p.grad`` (zeros when it is ``None``); the gradient/parameter shape; and a mask that is
            all ones when the parameter has a gradient and all zeros otherwise.
        """
        grad, shape, has_grad = [], [], []
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p, device=p.device))
                    has_grad.append(torch.zeros_like(p, device=p.device))
                    continue

                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p, device=p.device))

        return grad, shape, has_grad

    def pack_grad(self, objectives: Iterable[torch.Tensor]) -> Tuple[List[torch.Tensor], List[List[torch.Size]], List[torch.Tensor]]:
        """Pack the gradient of the parameters of the network for each objective.

        For each objective the gradients are cleared, ``objective.backward(retain_graph=True)`` is run, and the
        per-parameter gradients and masks are flattened with ``flatten_grad``.

        Args:
            objectives (Iterable[torch.Tensor]): A list of objectives (scalar losses to back-propagate).

        Returns:
            ``(grads, shapes, has_grads)``: one flattened gradient, one list of per-parameter shapes, and one
            flattened has-gradient mask per objective.
        """
        grads, shapes, has_grads = [], [], []
        for objective in objectives:
            self.optimizer.zero_grad(set_to_none=True)
            # The graph is retained so that later objectives can back-propagate through shared nodes.
            objective.backward(retain_graph=True)

            grad, shape, has_grad = self.retrieve_grad()

            grads.append(flatten_grad(grad))
            has_grads.append(flatten_grad(has_grad))
            shapes.append(shape)

        return grads, shapes, has_grads

    def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
        """Project conflicting.

        Each objective's gradient is projected, against every other objective's original gradient in random order,
        onto the normal plane of any gradient it has a negative dot product with. The projected gradients are then
        merged.

        Args:
            grads (List[torch.Tensor]): A list of the gradient of the parameters. Shuffled in place.
            has_grads (List[torch.Tensor]): A list of masks representing whether the parameter has gradient.

        Returns:
            torch.Tensor: The merged flat gradient. Entries shared by all objectives are reduced with
            ``self.reduction``; the remaining entries are summed.
        """
        # True where every objective produced a gradient.
        shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()

        pc_grad: List[torch.Tensor] = deepcopy(grads)
        for i, g_i in enumerate(pc_grad):
            # NOTE: shuffles the caller's ``grads`` list in place.
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j: torch.Tensor = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    # In-place subtraction: ``g_i`` is the same tensor, so later dot products see earlier projections.
                    pc_grad[i] -= g_i_g_j * g_j / (g_j.norm() ** 2)

        merged_grad: torch.Tensor = torch.zeros_like(grads[0])

        shared_pc_gradients: torch.Tensor = torch.stack([g[shared] for g in pc_grad])
        if self.reduction == 'mean':
            merged_grad[shared] = shared_pc_gradients.mean(dim=0)
        else:
            merged_grad[shared] = shared_pc_gradients.sum(dim=0)

        # Entries not covered by every objective are always summed, regardless of ``reduction``.
        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)

        return merged_grad

    def pc_backward(self, objectives: Iterable[torch.Tensor]) -> None:
        """Calculate the gradient of the parameters.

        Runs ``pack_grad`` and ``project_conflicting``, un-flattens the merged gradient, and stores it in the
        parameters' ``.grad`` via ``set_grad``.

        Args:
            objectives (Iterable[torch.Tensor]): A list of objectives (scalar losses to back-propagate).

        """
        grads, shapes, has_grads = self.pack_grad(objectives)

        pc_grad = self.project_conflicting(grads, has_grads)
        # Every objective records the same per-parameter shapes, so the first list suffices.
        pc_grad = un_flatten_grad(pc_grad, shapes[0])

        self.set_grad(pc_grad)

pack_grad(objectives)

Pack the gradient of the parameters of the network for each objective.

Parameters:

Name Type Description Default
objectives Iterable[Module]

A list of objectives.

required
Source code in pytorch_optimizer/optimizer/pcgrad.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
    """Back-propagate each objective separately and collect its flattened gradients.

    For every objective, the wrapped optimizer's gradients are reset to ``None``, then
    ``objective.backward(retain_graph=True)`` is run. The per-parameter results of ``retrieve_grad`` are then
    flattened with ``flatten_grad``.

    Args:
        objectives (Iterable[nn.Module]): A list of objectives.

    Returns:
        Three lists with one entry per objective: the flattened gradient, the per-parameter shapes, and the
        flattened has-gradient mask.
    """
    flat_grads: List[torch.Tensor] = []
    all_shapes = []
    flat_masks: List[torch.Tensor] = []

    for loss in objectives:
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)

        per_param_grads, per_param_shapes, per_param_masks = self.retrieve_grad()

        flat_grads.append(flatten_grad(per_param_grads))
        flat_masks.append(flatten_grad(per_param_masks))
        all_shapes.append(per_param_shapes)

    return flat_grads, all_shapes, flat_masks

pc_backward(objectives)

Calculate the gradient of the parameters.

Parameters:

Name Type Description Default
objectives Iterable[Module]

A list of objectives.

required
Source code in pytorch_optimizer/optimizer/pcgrad.py
128
129
130
131
132
133
134
135
136
137
138
139
140
def pc_backward(self, objectives: Iterable[nn.Module]) -> None:
    """Compute PCGrad-projected gradients for ``objectives`` and store them in the parameters' ``.grad``.

    Args:
        objectives (Iterable[nn.Module]): A list of objectives.

    """
    flat_grads, shapes_per_objective, flat_masks = self.pack_grad(objectives)

    merged = self.project_conflicting(flat_grads, flat_masks)

    # All objectives record identical per-parameter shapes, so the first entry is used.
    self.set_grad(un_flatten_grad(merged, shapes_per_objective[0]))

project_conflicting(grads, has_grads)

Project conflicting.

Parameters:

Name Type Description Default
grads List[Tensor]

A list of the gradient of the parameters.

required
has_grads List[Tensor]

A list of masks representing whether the parameter has gradient.

required
Source code in pytorch_optimizer/optimizer/pcgrad.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
    """Remove pairwise gradient conflicts (PCGrad) and merge the per-objective gradients.

    Each objective's gradient is checked against every original gradient, in a freshly shuffled order. When the
    dot product is negative, the conflicting component is projected out. Entries present in every objective are
    merged with ``self.reduction`` (``'mean'`` or sum); all other entries are summed.

    Args:
        grads (List[torch.Tensor]): A list of the gradient of the parameters. Shuffled in place.
        has_grads (List[torch.Tensor]): A list of masks representing whether the parameter has gradient.

    Returns:
        torch.Tensor: The merged flat gradient.
    """
    shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()

    projected: List[torch.Tensor] = deepcopy(grads)
    for g_i in projected:
        random.shuffle(grads)
        for g_j in grads:
            dot: torch.Tensor = torch.dot(g_i, g_j)
            if dot < 0:
                # Tensor ``-=`` is in place, so the entry stored in ``projected`` is updated too.
                g_i -= dot * g_j / (g_j.norm() ** 2)

    merged_grad: torch.Tensor = torch.zeros_like(grads[0])

    reduce_shared = torch.mean if self.reduction == 'mean' else torch.sum
    merged_grad[shared] = reduce_shared(torch.stack([g[shared] for g in projected]), dim=0)

    not_shared = ~shared
    merged_grad[not_shared] = torch.stack([g[not_shared] for g in projected]).sum(dim=0)

    return merged_grad

retrieve_grad()

Get the gradient of the parameters of the network with specific objective.

Source code in pytorch_optimizer/optimizer/pcgrad.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
    """Snapshot the current gradient of every wrapped parameter, in ``param_groups`` order.

    Returns:
        ``(grads, shapes, has_grads)``. Each per-parameter entry holds: a clone of ``p.grad``, or zeros when it
        is ``None``; the gradient's shape, or the parameter's shape when there is no gradient; and a mask of ones
        (gradient present) or zeros (absent) with the parameter's shape.
    """
    grads: List[torch.Tensor] = []
    shapes = []
    masks: List[torch.Tensor] = []

    for group in self.optimizer.param_groups:
        for param in group['params']:
            has_gradient = param.grad is not None

            shapes.append(param.grad.shape if has_gradient else param.shape)
            grads.append(param.grad.clone() if has_gradient else torch.zeros_like(param, device=param.device))

            make_mask = torch.ones_like if has_gradient else torch.zeros_like
            masks.append(make_mask(param, device=param.device))

    return grads, shapes, masks

PID

Bases: BaseOptimizer

A PID Controller Approach for Stochastic Optimization of Deep Networks.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
momentum float

Momentum factor.

0.0
dampening float

Dampening for momentum.

0.0
derivative float

Gain of the derivative (D) term of the PID controller.

10.0
integral float

Gain of the integral (I) term of the PID controller.

5.0
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/pid.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class PID(BaseOptimizer):
    """A PID Controller Approach for Stochastic Optimization of Deep Networks.

    When ``momentum > 0`` the applied update is ``grad + integral * I + derivative * D``. ``I`` is a
    momentum-weighted accumulation of gradients; ``D`` is a momentum-weighted average of successive gradient
    differences. With ``momentum == 0`` the update is plain SGD on the (possibly weight-decayed) gradient.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float): Momentum factor.
        dampening (float): Dampening for momentum.
        derivative (float): Gain of the derivative (D) term.
        integral (float): Gain of the integral (I) term.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        derivative: float = 10.0,
        integral: float = 5.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        # NOTE(review): ``dampening`` is not validated.
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(derivative, 'derivative')
        self.validate_non_negative(integral, 'integral')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'derivative': derivative,
            'integral': integral,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PID'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            # Buffers exist only when momentum is enabled; otherwise the state stays empty.
            if len(state) == 0 and group['momentum'] > 0.0:
                state['grad_buffer'] = torch.zeros_like(p)
                state['i_buffer'] = torch.zeros_like(p)
                state['d_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # All three buffers are None when momentum == 0. NOTE(review): this assumes ``view_as_real``
                # passes None through unchanged.
                g_buf, i_buf, d_buf = (
                    state.get('grad_buffer', None),
                    state.get('i_buffer', None),
                    state.get('d_buffer', None),
                )

                p, grad, g_buf, i_buf, d_buf = self.view_as_real(p, grad, g_buf, i_buf, d_buf)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['momentum'] > 0.0:
                    # Integral term: momentum-decayed sum of (dampened) gradients.
                    i_buf.mul_(group['momentum']).add_(grad, alpha=1.0 - group['dampening'])
                    d_buf.mul_(group['momentum'])

                    # Derivative term: momentum-weighted average of successive gradient differences.
                    # NOTE(review): ``g_buf`` is only written once ``step > 1``. At the group's second step the
                    # difference is therefore taken against the zero-initialized buffer (it equals the full gradient).
                    # Confirm that this is intended.
                    if group['step'] > 1:
                        d_buf.add_(grad - g_buf, alpha=1.0 - group['momentum'])
                        g_buf.copy_(grad)

                    # ``grad`` is updated in place to become the full PID update.
                    grad.add_(i_buf, alpha=group['integral']).add_(d_buf, alpha=group['derivative'])

                p.add_(grad, alpha=-group['lr'])

        return loss

PNM

Bases: BaseOptimizer

Positive-Negative Momentum.

Parameters:

Name Type Description Default
params ParamsT

Iterable of the parameters to optimize.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 1.0)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Use weight decoupling, as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/pnm.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class PNM(BaseOptimizer):
    """Positive-Negative Momentum.

    Each parameter owns two momentum buffers whose roles alternate every step: on odd steps the first buffer acts as
    the "positive" momentum (it absorbs the current gradient) and the second as the "negative" one, and on even steps
    the roles are swapped. The applied update is the freshly refreshed positive momentum scaled by ``1 + beta2``,
    minus ``beta2`` times the stale negative momentum, divided by ``sqrt((1 + beta2) ** 2 + beta2 ** 2)``.

    Args:
        params (ParamsT): Iterable of the parameters to optimize.
        lr (float): Learning rate.
        betas (Betas): Coefficients (beta1, beta2). beta1 sets the momentum decay (the running average uses
            ``beta1 ** 2``); beta2 sets the weight of the negative momentum term.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Use weight decoupling, as in AdamW.
        fixed_decay (bool): Fix weight decay.
        eps (float): Validated for non-negativity but not used by the update.
        maximize (bool): Maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas, beta_range_type='[]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PNM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        # Start the per-group step counter on first use.
        group.setdefault('step', 0)

        for p in group['params']:
            grad = p.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]
            if not state:
                # Both momentum buffers start at zero; their roles alternate in `step`.
                state['pos_momentum'] = torch.zeros_like(p)
                state['neg_momentum'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            momentum_coef, neg_weight = group['betas']

            ema_decay: float = momentum_coef ** 2  # fmt: skip
            inv_noise_norm: float = 1.0 / math.sqrt((1 + neg_weight) ** 2 + neg_weight ** 2)  # fmt: skip

            # On even steps the two stored buffers swap roles.
            swap_roles: bool = group['step'] % 2 == 0

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]
                first_buf, second_buf = state['pos_momentum'], state['neg_momentum']
                pos_momentum, neg_momentum = (second_buf, first_buf) if swap_roles else (first_buf, second_buf)

                p, grad, pos_momentum, neg_momentum = self.view_as_real(p, grad, pos_momentum, neg_momentum)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # Refresh the positive momentum with the current gradient.
                pos_momentum.mul_(ema_decay).add_(grad, alpha=1.0 - ema_decay)

                # Combine positive and (stale) negative momentum, then normalize.
                update = pos_momentum.mul(1.0 + neg_weight)
                update.add_(neg_momentum, alpha=-neg_weight).mul_(inv_noise_norm)

                p.add_(update, alpha=-group['lr'])

        return loss

Prodigy

Bases: BaseOptimizer

An Expeditiously Adaptive Parameter-Free Learner.

Leave LR set to 1 unless you encounter instability.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

1.0
betas Betas

betas.

(0.9, 0.999)
beta3 float

coefficient for computing the Prodigy step-size using running averages. If set to None, the square root of beta2 is used.

None
d0 float

initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
d_coef float

Coefficient in the expression for the estimate of d.

1.0
growth_rate float

prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

use AdamW style weight decay.

True
fixed_decay bool

fix weight decay.

False
bias_correction bool

turn on Adam's bias correction.

False
safeguard_warmup bool

remove lr from the denominator of D estimate to avoid issues during warm-up stage.

False
eps float

term added to the denominator to improve numerical stability. when eps is None, use atan2 rather than epsilon and division for parameter updates.

1e-08
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/prodigy.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
class Prodigy(BaseOptimizer):
    """An Expeditiously Adaptive Parameter-Free Learner.

    Leave LR set to 1 unless you encounter instability.

    Each call to ``step`` runs two passes over all parameter groups. The first pass updates the Adam moments and
    accumulates the numerator/denominator of the D estimate. The second pass publishes the new D statistics to every
    group, advances the step counter exactly once, and applies the parameter update.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        betas (Betas): betas.
        beta3 (float): coefficient for computing the Prodigy step-size using running averages. If set to None,
            the square root of beta2 is used.
        d0 (float): initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        d_coef (float): Coefficient in the expression for the estimate of d.
        growth_rate (float): prevent the D estimate from growing faster than this multiplicative rate.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): use AdamW style weight decay.
        fixed_decay (bool): fix weight decay.
        bias_correction (bool): turn on Adam's bias correction.
        safeguard_warmup (bool): remove lr from the denominator of D estimate to avoid issues during warm-up stage.
        eps (float): term added to the denominator to improve numerical stability. when eps is None, use atan2 rather
            than epsilon and division for parameter updates.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        betas: Betas = (0.9, 0.999),
        beta3: Optional[float] = None,
        d0: float = 1e-6,
        d_coef: float = 1.0,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        bias_correction: bool = False,
        safeguard_warmup: bool = False,
        eps: Optional[float] = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas((*betas, beta3))
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'beta3': beta3,
            'd': d0,
            'd0': d0,
            'd_max': d0,
            'd_coef': d_coef,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'bias_correction': bias_correction,
            'safeguard_warmup': safeguard_warmup,
            'step': 1,  # 1-based: used directly by the bias-correction terms in `step`.
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Prodigy'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        # 'step' is already provided by the defaults; this only guards groups that lack it.
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['s'] = torch.zeros_like(p)
                state['p0'] = p.clone()  # reference point for the D-estimate numerator
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Step-level constants are read from the first parameter group.
        group = self.param_groups[0]
        device = group['params'][0].device

        d_de_nom = torch.tensor([0.0], device=device)

        beta1, beta2 = group['betas']
        beta3: float = group['beta3'] if group['beta3'] is not None else math.sqrt(beta2)

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
        bias_correction: float = (bias_correction1 / bias_correction2_sq) if group['bias_correction'] else 1.0

        d, d0 = group['d'], group['d0']
        d_lr: float = d * group['lr'] / bias_correction

        if 'd_numerator' not in group:
            group['d_numerator'] = torch.tensor([0.0], device=device)
        elif group['d_numerator'].device != device:
            group['d_numerator'] = group['d_numerator'].to(device)  # pragma: no cover

        d_numerator = group['d_numerator']
        d_numerator.mul_(beta3)

        # Pass 1: update the moments and accumulate the D-estimate statistics.
        for group in self.param_groups:
            # Create state for every parameter that has a gradient but no state yet. Running this on every call
            # (not only on the first step) also covers parameters whose gradient first appears later. The step
            # counter is NOT advanced here; it is advanced exactly once, in pass 2.
            self.init_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                p0, exp_avg, exp_avg_sq = state['p0'], state['exp_avg'], state['exp_avg_sq']

                # d / d0 is used instead of d to keep the accumulated values from becoming too small.
                d_numerator.add_(torch.dot(grad.flatten(), (p0 - p).flatten()), alpha=(d / d0) * d_lr)

                exp_avg.mul_(beta1).add_(grad, alpha=d * (1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=(d / d0) * (d if group['safeguard_warmup'] else d_lr))

                d_de_nom.add_(s.abs().sum())

        if d_de_nom == 0:
            return loss

        # NOTE: `group` is now bound to the last parameter group, so d_coef / d0 / d_max are read from it.
        d_hat = (group['d_coef'] * d_numerator / d_de_nom).item()
        if d == group['d0']:
            d = max(d, d_hat)

        d_max = max(group['d_max'], d_hat)
        d = min(d_max, d * group['growth_rate'])

        # Pass 2: publish the new D statistics, advance the step counter once, and update the parameters.
        for group in self.param_groups:
            group['step'] += 1

            group['d_numerator'] = d_numerator
            group['d_de_nom'] = d_de_nom
            group['d'] = d
            group['d_max'] = d_max
            group['d_hat'] = d_hat

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                self.apply_weight_decay(
                    p,
                    p.grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                de_nom = exp_avg_sq.sqrt()

                if group['eps'] is not None:
                    de_nom.add_(d * group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-d_lr)
                else:
                    # eps=None: use atan2 instead of epsilon + division.
                    update = exp_avg.clone().atan2_(de_nom)
                    p.add_(update, alpha=-d_lr)

        return loss

QHAdam

Bases: BaseOptimizer

Quasi-hyperbolic momentum and Adam for deep learning.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
betas Betas

coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
nus Tuple[float, float]

immediate discount factors used to estimate the gradient and its square.

(1.0, 1.0)
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

fix weight decay.

False
eps float

term added to the denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/qhadam.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class QHAdam(BaseOptimizer):
    """Quasi-hyperbolic momentum and Adam for deep learning.

    The first and second moments are exponential moving averages of the gradient and the squared gradient. Their
    decay rates are bias-adjusted per parameter (see ``step``). The update uses an nu-weighted mix of each moment and
    the corresponding current gradient statistic.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        betas (Betas): coefficients used for computing running averages of the gradient and its square.
        nus (Tuple[float, float]): immediate discount factors used to estimate the gradient and its square.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        eps (float): term added to the denominator to improve numerical stability.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        nus: Tuple[float, float] = (1.0, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_nus(nus)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'nus': nus,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'QHAdam'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # Running sums 1 + beta + beta**2 + ... used to bias-adjust the moment decay rates.
                state['beta1_weight'] = torch.zeros((1,), dtype=torch.float32, device=grad.device)
                state['beta2_weight'] = torch.zeros((1,), dtype=torch.float32, device=grad.device)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']
            nu1, nu2 = group['nus']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # weight_t = beta * weight_{t-1} + 1, i.e. 1 + beta + ... + beta**(t-1).
                beta1_weight, beta2_weight = state['beta1_weight'], state['beta2_weight']
                beta1_weight.mul_(beta1).add_(1.0)
                beta2_weight.mul_(beta2).add_(1.0)

                # The adjusted decay is 0 on the first step, so each moment starts at the current statistic.
                beta1_adj = 1.0 - (1.0 / beta1_weight)
                beta2_adj = 1.0 - (1.0 / beta2_weight)

                grad_p2 = grad.pow(2)

                # EMAs with tensor-valued decay: m = adj * m + (1 - adj) * x. The parentheses around (1 - adj) are
                # required; without them the second moment would be incremented by 1 - adj * grad**2.
                exp_avg.mul_(beta1_adj).add_((1.0 - beta1_adj) * grad)
                exp_avg_sq.mul_(beta2_adj).add_((1.0 - beta2_adj) * grad_p2)

                # Quasi-hyperbolic mixing of each moment with the current gradient statistic.
                avg_grad = exp_avg.mul(nu1)
                if nu1 != 1.0:
                    avg_grad.add_(grad, alpha=1.0 - nu1)

                avg_grad_rms = exp_avg_sq.mul(nu2)
                if nu2 != 1.0:
                    avg_grad_rms.add_(grad_p2, alpha=1.0 - nu2)

                avg_grad_rms.sqrt_().add_(group['eps'])

                p.addcdiv_(avg_grad, avg_grad_rms, value=-group['lr'])

        return loss

QHM

Bases: BaseOptimizer

Quasi-hyperbolic momentum (QHM) optimization algorithm.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
momentum float

momentum factor.

0.0
nu float

immediate discount factor: the weight of the momentum buffer relative to the raw gradient in the update.

1.0
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

fix weight decay.

False
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/qhm.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
class QHM(BaseOptimizer):
    """Quasi-hyperbolic momentum (QHM) optimization algorithm.

    A momentum buffer tracks an exponential moving average of the gradient. The parameter moves by an nu-weighted
    blend of that buffer and the raw gradient: ``p -= lr * (nu * buf + (1 - nu) * grad)``.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): momentum factor (decay of the gradient moving average).
        nu (float): immediate discount factor; weight of the momentum buffer relative to the raw gradient.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.0,
        nu: float = 1.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_nus(nu)

        self.maximize = maximize

        defaults: Defaults = dict(
            lr=lr,
            momentum=momentum,
            nu=nu,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'QHM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        # Start the per-group step counter on first use.
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if not state:
                state['momentum_buffer'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            lr, momentum, nu = group['lr'], group['momentum'], group['nu']

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                param, grad, momentum_buffer = self.view_as_real(param, grad, self.state[param]['momentum_buffer'])

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # Exponential moving average of the gradient.
                momentum_buffer.mul_(momentum).add_(grad, alpha=1.0 - momentum)

                # Quasi-hyperbolic step: nu-weighted momentum plus (1 - nu)-weighted raw gradient.
                param.add_(momentum_buffer, alpha=-lr * nu)
                param.add_(grad, alpha=-lr * (1.0 - nu))

        return loss

RACS

Bases: BaseOptimizer

Row and Column Scaled SGD.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
beta float

momentum factor.

0.9
alpha float

scaling factor multiplied into the step size.

0.05
gamma float

norm-growth limiter threshold: the scaled update norm may grow by at most this factor per step.

1.01
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

fix weight decay.

False
eps float

term added to the denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/racs.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class RACS(BaseOptimizer):
    """Row and Column Scaled SGD.

    Every parameter is treated as a matrix. The gradient is divided element-wise by the square roots of a row-wise
    and a column-wise exponential moving average of the mean squared gradient. The norm of the scaled gradient is
    then limited so that it grows by at most a factor of ``gamma`` per step.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        beta (float): momentum factor for the row/column second-moment averages.
        alpha (float): scaling factor multiplied into the step size.
        gamma (float): norm-growth limiter threshold; the scaled update norm may grow by at most this factor per step.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        eps (float): term added to the denominator to improve numerical stability.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        beta: float = 0.9,
        alpha: float = 0.05,
        gamma: float = 1.01,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0)
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_positive(gamma, 'gamma')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'beta': beta,
            'alpha': alpha,
            'gamma': gamma,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'RACS'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        # Per-parameter state is created lazily in `step`, once the 2-D gradient shape is known.
        if 'step' not in group:
            group['step'] = 0

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta = group['beta']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                state = self.state[p]

                # Apply weight decay while `grad` still has the same shape as `p`. The coupled
                # (weight_decouple=False) path presumably combines the two element-wise, which cannot broadcast
                # once the gradient has been reshaped to 2-D below.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # View every gradient as a matrix: 0-D/1-D tensors become a single column, higher-rank tensors are
                # flattened to (size(0), -1).
                if grad.ndim < 2:
                    grad = grad.reshape(-1, 1)
                elif grad.ndim > 2:
                    grad = grad.reshape(grad.size(0), -1)

                if len(state) == 0:
                    # Row (s) and column (q) second-moment averages both start from zero.
                    state['s'] = torch.zeros(grad.size(0), dtype=grad.dtype, device=grad.device)
                    state['q'] = torch.zeros(grad.size(1), dtype=grad.dtype, device=grad.device)
                    state['theta'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                s, q = state['s'], state['q']

                grad_p2 = grad.pow(2)
                s.mul_(beta).add_(grad_p2.mean(dim=1), alpha=1.0 - beta)
                q.mul_(beta).add_(grad_p2.mean(dim=0), alpha=1.0 - beta)

                s_sq = s.add(group['eps']).sqrt_().unsqueeze(1)
                q_sq = q.add(group['eps']).sqrt_().unsqueeze(0)

                grad_hat = grad / (s_sq * q_sq)

                # Norm limiter: if the scaled-gradient norm grew by more than gamma relative to the previous
                # (limited) norm theta, shrink the step so the growth is exactly gamma. No limiting on the first step.
                grad_hat_norm = torch.norm(grad_hat)
                threshold = (
                    group['gamma'] / max(grad_hat_norm / (state['theta'] + group['eps']), group['gamma'])
                    if group['step'] > 1
                    else 1.0
                )
                state['theta'] = grad_hat_norm.mul_(threshold)

                p.add_(grad_hat.view_as(p), alpha=-group['lr'] * group['alpha'] * threshold)

        return loss

RAdam

Bases: BaseOptimizer

Rectified Adam.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
betas Betas

coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

fix weight decay.

False
n_sma_threshold int

recommended is 5.

5
degenerated_to_sgd bool

degenerated to SGD.

False
eps float

term added to the denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/radam.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class RAdam(BaseOptimizer):
    """Rectified Adam.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        betas (Betas): coefficients used for computing running averages of the gradient and the squared gradient.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        n_sma_threshold (int): minimum approximated SMA length required to take the adaptive (rectified) update.
            recommended is 5.
        degenerated_to_sgd (bool): whether to fall back to a momentum-only (SGD-like) update while the variance
            rectification term is not yet available.
        eps (float): term added to the denominator to improve numerical stability.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        # these are optimizer-wide settings, not per-group hyperparameters
        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'RAdam'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Lazily create the step counter and per-parameter moment buffers; reject sparse gradients."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                # only allocated when the optional AdaNorm variant is enabled through **kwargs
                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])

            # step_size and the approximated SMA length are shared by every parameter of the group
            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                # weight decay is only applied on steps that actually update the parameter (see branches below)
                if step_size > 0 or n_sma >= self.n_sma_threshold:
                    self.apply_weight_decay(
                        p=p,
                        grad=None,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if n_sma >= self.n_sma_threshold:
                    # rectified adaptive update
                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    # momentum-only update; a non-positive step_size means the update is skipped
                    p.add_(exp_avg, alpha=-step_size)

        return loss

Ranger

Bases: BaseOptimizer

A synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
betas Betas

coefficients used for computing running averages of the gradient and the squared gradient.

(0.95, 0.999)
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

fix weight decay.

False
n_sma_threshold int

recommended is 5.

5
degenerated_to_sgd bool

perform SGD update when variance of gradient is high.

False
use_gc bool

use Gradient Centralization (both convolution & fc layers).

True
gc_conv_only bool

use Gradient Centralization (only convolution layer).

False
eps float

term added to the denominator to improve numerical stability.

1e-05
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ranger.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class Ranger(BaseOptimizer):
    """A synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        betas (Betas): coefficients used for computing running averages of the gradient and the squared gradient.
        alpha (float): Lookahead interpolation factor. Every `k` steps the slow weights move this fraction of the
            way toward the fast weights, and the fast weights are reset to them.
        k (int): number of steps between Lookahead synchronizations.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        n_sma_threshold (int): recommended is 5.
        degenerated_to_sgd (bool): perform SGD update when variance of gradient is high.
        use_gc (bool): use Gradient Centralization (both convolution & fc layers).
        gc_conv_only (bool): use Gradient Centralization (only convolution layer).
        eps (float): term added to the denominator to improve numerical stability.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.95, 0.999),
        alpha: float = 0.5,
        k: int = 6,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = False,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        use_gc: bool = True,
        gc_conv_only: bool = False,
        eps: float = 1e-5,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(alpha, 'alpha', 0.0, 1.0, range_type='[]')
        self.validate_positive(k, 'k')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.use_gc = use_gc
        # GC is applied only to gradients with more dimensions than this threshold
        # (> 3 dims: convolution kernels only; > 1 dim: convolution and fc layers)
        self.gc_gradient_threshold: int = 3 if gc_conv_only else 1
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'k': k,
            'step_counter': 0,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Ranger'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Lazily create the step counter, moment buffers and Lookahead slow weights; reject sparse gradients."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['slow_buffer'] = p.clone()

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, slow_buffer = state['exp_avg'], state['exp_avg_sq'], state['slow_buffer']

                p, grad, exp_avg, exp_avg_sq, slow_buffer = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, slow_buffer
                )

                if self.use_gc and grad.dim() > self.gc_gradient_threshold:
                    centralize_gradient(grad, gc_conv_only=False)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if n_sma >= self.n_sma_threshold:
                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    # momentum-only (SGD-like) update while the rectification term is unavailable.
                    # A non-positive step_size means "skip the update" (same guard as RAdam). Stepping with
                    # -step_size in that case would move the parameters in the ascent direction.
                    p.add_(exp_avg, alpha=-step_size)

                # Lookahead: every k steps, pull the slow weights toward the fast ones and reset the fast weights
                if group['step'] % group['k'] == 0:
                    slow_buffer.lerp_(p, weight=group['alpha'])
                    p.copy_(slow_buffer)

        return loss

Ranger21

Bases: BaseOptimizer

Integrating the latest deep learning components into a single optimizer.

Here are the components: * uses the AdamW optimizer as its core (or, optionally, MadGrad) * Adaptive gradient clipping * Gradient centralization * Positive-Negative momentum * Norm loss * Stable weight decay * Linear learning rate warm-up * Explore-exploit learning rate schedule * Lookahead * Softplus transformation * Gradient Normalization * Corrects the denominator (AdamD).

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
num_iterations int

number of the total training steps. Ranger21 optimizer schedules the learning rate with its own recipes.

required
lr float

learning rate.

0.001
beta0 float

Manages the amplitude of the noise introduced by positive-negative momentum. While 0.9 is the recommended default value, you can use -0.5 to minimize the noise.

0.9
betas Betas

coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
use_softplus bool

use softplus to smooth.

True
beta_softplus float

beta.

50.0
disable_lr_scheduler bool

whether to disable learning rate schedule.

False
num_warm_up_iterations Optional[int]

number of warm-up iterations. Ranger21 performs linear learning rate warmup.

None
num_warm_down_iterations Optional[int]

number of warm-down iterations. Ranger21 performs Explore-exploit learning rate scheduling.

None
agc_clipping_value float

clipping value for adaptive gradient clipping (AGC).

0.01
agc_eps float

eps for AGC

0.001
centralize_gradients bool

use GC both convolution & fc layers.

True
normalize_gradients bool

use gradient normalization.

True
lookahead_merge_time int

merge time.

5
lookahead_blending_alpha float

blending alpha.

0.5
weight_decay float

weight decay (L2 penalty).

0.0001
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

fix weight decay.

False
norm_loss_factor float

norm loss factor.

0.0001
eps float

term added to the denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ranger21.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
class Ranger21(BaseOptimizer):
    """Integrating the latest deep learning components into a single optimizer.

    Here are the components:
        * uses the AdamW optimizer as its core (or, optionally, MadGrad)
        * Adaptive gradient clipping
        * Gradient centralization
        * Positive-Negative momentum
        * Norm loss
        * Stable weight decay
        * Linear learning rate warm-up
        * Explore-exploit learning rate schedule
        * Lookahead
        * Softplus transformation
        * Gradient Normalization
        * Corrects the denominator (AdamD).

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        num_iterations (int): number of the total training steps. Ranger21 optimizer schedules the learning rate
            with its own recipes.
        lr (float): learning rate.
        beta0 (float): Manages the amplitude of the noise introduced by positive-negative momentum.
            While 0.9 is the recommended default value, you can use -0.5 to minimize the noise.
        betas (Betas): coefficients used for computing running averages of the gradient and the squared gradient.
        use_softplus (bool): use softplus to smooth.
        beta_softplus (float): beta.
        disable_lr_scheduler (bool): whether to disable learning rate schedule.
        num_warm_up_iterations (Optional[int]): number of warm-up iterations. Ranger21 performs linear learning rate
            warmup.
        num_warm_down_iterations (Optional[int]): number of warm-down iterations. Ranger21 performs Explore-exploit
            learning rate scheduling.
        warm_down_min_lr (float): lower bound of the learning rate during warm-down, reached at the end of the
            schedule.
        agc_clipping_value (float): clipping value for adaptive gradient clipping (AGC).
        agc_eps (float): eps for AGC
        centralize_gradients (bool): use GC both convolution & fc layers.
        normalize_gradients (bool): use gradient normalization.
        lookahead_merge_time (int): number of steps between Lookahead merges.
        lookahead_blending_alpha (float): blending alpha. At each merge the parameters become
            `alpha * fast + (1 - alpha) * slow`.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        norm_loss_factor (float): norm loss factor.
        eps (float): term added to the denominator to improve numerical stability.
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(  # pylint: disable=R0913
        self,
        params: ParamsT,
        num_iterations: int,
        lr: float = 1e-3,
        beta0: float = 0.9,
        betas: Betas = (0.9, 0.999),
        use_softplus: bool = True,
        beta_softplus: float = 50.0,
        disable_lr_scheduler: bool = False,
        num_warm_up_iterations: Optional[int] = None,
        num_warm_down_iterations: Optional[int] = None,
        warm_down_min_lr: float = 3e-5,
        agc_clipping_value: float = 1e-2,
        agc_eps: float = 1e-3,
        centralize_gradients: bool = True,
        normalize_gradients: bool = True,
        lookahead_merge_time: int = 5,
        lookahead_blending_alpha: float = 0.5,
        weight_decay: float = 1e-4,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        norm_loss_factor: float = 1e-4,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(warm_down_min_lr)
        self.validate_betas(betas)
        # NOTE(review): beta0 is validated here but not stored or used anywhere below; the PN momentum in `step`
        # uses fixed coefficients (2.0, -1.0). Also, [0, 1] rejects the -0.5 suggested in the docstring. Confirm.
        self.validate_range(beta0, 'beta0', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(agc_clipping_value, 'agc_clipping_value')
        self.validate_non_negative(eps, 'eps')
        self.validate_non_negative(agc_eps, 'agc_eps')

        self.min_lr = warm_down_min_lr
        self.use_softplus = use_softplus
        self.beta_softplus = beta_softplus
        self.disable_lr_scheduler = disable_lr_scheduler
        self.agc_clipping_value = agc_clipping_value
        self.agc_eps = agc_eps
        self.centralize_gradients = centralize_gradients
        self.normalize_gradients = normalize_gradients
        self.lookahead_merge_time = lookahead_merge_time
        self.lookahead_blending_alpha = lookahead_blending_alpha
        self.norm_loss_factor = norm_loss_factor
        self.maximize = maximize

        self.lookahead_step: int = 0
        self.starting_lr: float = lr
        self.current_lr: float = lr

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

        self.num_warm_up_iterations: int = (
            self.build_warm_up_iterations(num_iterations, betas[1])
            if num_warm_up_iterations is None
            else num_warm_up_iterations
        )
        self.num_warm_down_iterations: int = (
            self.build_warm_down_iterations(num_iterations)
            if num_warm_down_iterations is None
            else num_warm_down_iterations
        )
        self.start_warm_down: int = num_iterations - self.num_warm_down_iterations
        self.warm_down_lr_delta: float = self.starting_lr - self.min_lr

    def __str__(self) -> str:
        return 'Ranger21'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Lazily create the step counter and per-parameter buffers; reject sparse and complex parameters."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['grad_ma'] = torch.zeros_like(p)
                state['variance_ma'] = torch.zeros_like(p)
                state['lookahead_params'] = p.clone()
                state['neg_grad_ma'] = torch.zeros_like(p)
                state['max_variance_ma'] = torch.zeros_like(p)

    @staticmethod
    def build_warm_up_iterations(total_iterations: int, beta2: float, warm_up_pct: float = 0.22) -> int:
        """Return the warm-up length: ceil(2 / (1 - beta2)), or warm_up_pct of training if that exceeds 45%."""
        warm_up_iterations: int = math.ceil(2.0 / (1.0 - beta2))  # default un-tuned linear warmup
        beta_pct: float = warm_up_iterations / total_iterations
        return int(warm_up_pct * total_iterations) if beta_pct > 0.45 else warm_up_iterations

    @staticmethod
    def build_warm_down_iterations(total_iterations: int, warm_down_pct: float = 0.72) -> int:
        """Return the warm-down length, i.e. the iterations remaining after warm_down_pct of training."""
        start_warm_down: int = int(warm_down_pct * total_iterations)
        return total_iterations - start_warm_down

    def warm_up_dampening(self, lr: float, step: int) -> float:
        """Scale lr linearly by step / num_warm_up_iterations during warm-up; return lr unchanged afterwards."""
        if step > self.num_warm_up_iterations:
            return lr

        warm_up_current_pct: float = min(1.0, (step / self.num_warm_up_iterations))

        self.current_lr = lr * warm_up_current_pct

        return self.current_lr

    def warm_down(self, lr: float, iteration: int) -> float:
        """Linearly decay from starting_lr toward min_lr once `iteration` reaches start_warm_down."""
        if iteration < self.start_warm_down:
            return lr

        # start iteration from 1, not 0
        warm_down_iteration: int = max((iteration + 1) - self.start_warm_down, 1)
        warm_down_pct: float = min(warm_down_iteration / (self.num_warm_down_iterations + 1), 1.0)

        self.current_lr = max(self.starting_lr - self.warm_down_lr_delta * warm_down_pct, self.min_lr)

        return self.current_lr

    def _centralize_and_normalize(self, grad: torch.Tensor) -> None:
        """Apply gradient centralization and/or gradient normalization in place, as enabled in the constructor."""
        if self.centralize_gradients:
            centralize_gradient(grad, gc_conv_only=False)
        if self.normalize_gradients:
            normalize_gradient(grad)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        variance_ma_sum: float = 1.0

        # first pass: preprocess gradients, update the second moments, and accumulate their bias-corrected sum.
        # The global normalized variance scales the stable weight decay in the second pass.
        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta2 = group['betas'][1]

            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                param_size += p.numel()

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad.copy_(agc(p, grad, self.agc_eps, self.agc_clipping_value))

                self._centralize_and_normalize(grad)

                variance_ma = state['variance_ma']
                variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                variance_ma_sum += (variance_ma / bias_correction2).sum()

        if param_size == 0:
            raise ZeroParameterSizeError

        variance_normalized = math.sqrt(variance_ma_sum / param_size)

        # second pass: stable weight decay, norm loss, and the positive-negative momentum update
        for group in self.param_groups:
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

            if self.disable_lr_scheduler:
                lr: float = group['lr']
            else:
                lr: float = self.warm_up_dampening(group['lr'], group['step'])
                lr = self.warm_down(lr, group['step'])

            step_size: float = self.apply_adam_debias(group.get('adam_debias', False), lr, bias_correction1)

            for p in group['params']:
                if p.grad is None:
                    continue

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=1.0 / variance_normalized,
                )

                correction = 2.0 * self.norm_loss_factor * (1.0 - 1.0 / unit_norm(p).add_(group['eps']))
                p.mul_(1.0 - lr * correction)

                # the positive / negative momentum buffers swap roles on alternating steps
                state = self.state[p]
                if group['step'] % 2 == 1:
                    grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
                else:
                    grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']

                variance_ma = state['variance_ma']
                # NOTE(review): the max is written back into variance_ma, so max_variance_ma is never updated
                # (it stays zero) and this line has no effect on variance_ma. If an AMSGrad-style running max
                # is intended, it should use out=max_variance_ma and build de_nom from it. Confirm before changing.
                torch.max(state['max_variance_ma'], variance_ma, out=variance_ma)

                de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps'])

                if self.use_softplus:
                    de_nom = softplus(de_nom, beta=self.beta_softplus)

                # p.grad was already processed in the first pass; it is processed again here (subject to the same
                # flags) to keep the existing numerics.
                grad = p.grad
                self._centralize_and_normalize(grad)

                grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1.0 - beta1 ** 2)  # fmt: skip

                pn_momentum = grad_ma.mul(2.0).add_(neg_grad_ma, alpha=-1.0).mul_(1.0 / noise_norm)
                p.addcdiv_(pn_momentum, de_nom, value=-step_size)

        self.lookahead_process_step()

        return loss

    def lookahead_process_step(self):
        """Every `lookahead_merge_time` calls, blend the parameters with the slow weights and refresh them."""
        self.lookahead_step += 1
        if self.lookahead_step >= self.lookahead_merge_time:
            self.lookahead_step: int = 0
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    state = self.state[p]

                    p.mul_(self.lookahead_blending_alpha).add_(
                        state['lookahead_params'],
                        alpha=1.0 - self.lookahead_blending_alpha,
                    )
                    state['lookahead_params'].copy_(p)

Ranger25

Bases: BaseOptimizer

Mixing in every fancy optimizer hack.

Here are the components: * ADOPT * AdEMAMix * Cautious * StableAdamW or Adam-atan2 * OrthoGrad * Adaptive gradient clipping * Lookahead * Cautious Weight Decay

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.98, 0.9999)
weight_decay float

Weight decay (L2 penalty).

0.001
alpha float

Usually between 4 and 10 works well.

5.0
t_alpha_beta3 Optional[float]

Total number of iterations is preferred when needed.

None
cautious bool

Whether to use the Cautious variant.

True
stable_adamw bool

Whether to use stable AdamW variant.

True
orthograd bool

Whether to use OrthoGrad variant.

True
eps Optional[float]

Term added to the denominator to improve numerical stability. When eps is None and stable_adamw is False, adam-atan2 feature will be used.

1e-08
maximize bool

Maximize the objective w.r.t the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/experimental/ranger25.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
class Ranger25(BaseOptimizer):
    """Mixes in every fancy optimizer hack.

    Here are the components:
        * ADOPT
        * AdEMAMix
        * Cautious
        * StableAdamW or Adam-atan2
        * OrthoGrad
        * Adaptive gradient clipping
        * Lookahead
        * Cautious Weight Decay

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients (beta1, beta2, beta3) for the running averages of the normalized gradient, the
            squared gradient, and the slow (AdEMAMix) average of the normalized gradient, respectively.
        weight_decay (float): Weight decay coefficient, applied through cautious weight decay.
        alpha (float): AdEMAMix mixing coefficient for the slow average. Usually between 4 and 10 works well.
        t_alpha_beta3 (Optional[float]): Warm-up horizon of the alpha and beta3 schedules. Total number of iterations
            is preferred when needed. When None, alpha and beta3 are used as-is.
        lookahead_merge_time (int): Merge the fast weights into the Lookahead slow weights every this many steps.
        lookahead_blending_alpha (float): Interpolation weight toward the fast weights when merging, in [0, 1].
        cautious (bool): Whether to use the Cautious variant.
        stable_adamw (bool): Whether to use stable AdamW variant. Only honored when `eps` is a float.
        orthograd (bool): Whether to use OrthoGrad variant.
        eps (Optional[float]): Term added to the denominator to improve numerical stability.
            When eps is None and stable_adamw is False, adam-atan2 feature will be used.
        maximize (bool): Maximize the objective w.r.t the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.98, 0.9999),
        weight_decay: float = 1e-3,
        alpha: float = 5.0,
        t_alpha_beta3: Optional[float] = None,
        lookahead_merge_time: int = 5,
        lookahead_blending_alpha: float = 0.5,
        cautious: bool = True,
        stable_adamw: bool = True,
        orthograd: bool = True,
        eps: Optional[float] = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(alpha, 'alpha')
        self.validate_non_negative(t_alpha_beta3, 't_alpha_beta3')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(lookahead_merge_time, 'lookahead_merge_time')
        self.validate_range(lookahead_blending_alpha, 'lookahead_blending_alpha', 0.0, 1.0, '[]')
        self.validate_non_negative(eps, 'eps')

        self.lookahead_merge_time = lookahead_merge_time
        self.lookahead_blending_alpha = lookahead_blending_alpha
        self.cautious = cautious
        # StableAdamW is only enabled when a float eps is given; eps=None selects the Adam-atan2 path instead.
        self.stable_adamw: bool = stable_adamw if isinstance(eps, float) else False
        self.orthograd = orthograd
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'alpha': alpha,
            't_alpha_beta3': t_alpha_beta3,
            'eps': eps if (eps is not None) or (eps is None and not stable_adamw) else 1e-8,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Ranger25'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and per-parameter state, rejecting sparse grads and complex params."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(grad)
                state['exp_avg_sq'] = torch.zeros_like(grad)
                state['exp_avg_slow'] = torch.zeros_like(grad)
                # Lookahead slow weights start at the initial parameter values.
                state['slow_momentum'] = p.clone()

    @staticmethod
    def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float:
        """Linearly warm alpha up to its target over `t_alpha_beta3` steps (constant when None)."""
        return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha)

    @staticmethod
    def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float:
        """Warm beta3 up from beta1 toward its target in log-space (constant when `t_alpha_beta3` is None)."""
        if t_alpha_beta3 is None:
            return beta3

        log_beta1, log_beta3 = math.log(beta1), math.log(beta3)

        return min(
            math.exp(
                log_beta1 * log_beta3 / ((1.0 - step / t_alpha_beta3) * log_beta3 + (step / t_alpha_beta3) * log_beta1)
            ),
            beta3,
        )

    @torch.no_grad()
    def apply_orthogonal_gradients(self, params, eps: float = 1e-16) -> None:
        """Replace each gradient with its component orthogonal to the weights, rescaled to the original norm."""
        for p in params:
            if p.grad is None or p.grad.is_sparse or torch.is_complex(p):
                continue

            w = p.view(-1)
            g = p.grad.view(-1)

            proj = torch.dot(w, g).div_(torch.dot(w, w).add_(eps))
            g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
            g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(eps)))

            p.grad.copy_(g_ortho_scaled.view_as(p.grad))

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): Optional closure that re-evaluates the model and returns the loss.

        Returns:
            Loss: The loss returned by `closure`, or None.

        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if self.orthograd:
            for group in self.param_groups:
                self.apply_orthogonal_gradients(group['params'])

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2, beta3 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] / bias_correction1
            clip: float = math.pow(group['step'], 0.25)

            alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha'])
            beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad.copy_(agc(p, grad))

                exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']

                # ADOPT: normalize by the second moment *before* it is updated with the current gradient.
                normed_grad = grad.div(
                    exp_avg_sq.sqrt().clamp_(min=group['eps'] if group['eps'] is not None else 1e-8)
                ).clamp_(-clip, clip)

                exp_avg.mul_(beta1).add_(normed_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                exp_avg_slow.mul_(beta3_t).add_(normed_grad, alpha=1.0 - beta3_t)

                update = exp_avg.clone()

                self.apply_cautious_weight_decay(p, update, group['lr'], group['weight_decay'])

                if self.cautious:
                    self.apply_cautious(update, grad)

                # Per-parameter step size: the StableAdamW RMS scaling belongs to this parameter only and must not
                # leak into (and compound across) the remaining parameters of the group.
                param_step_size = step_size
                if self.stable_adamw:
                    param_step_size /= self.get_stable_adamw_rms(grad, exp_avg_sq)

                update.add_(exp_avg_slow, alpha=alpha_t)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq)

                if group['eps'] is not None:
                    p.addcdiv_(update, de_nom.add_(group['eps']), value=-param_step_size)
                else:
                    p.add_(update.atan2_(de_nom), alpha=-param_step_size)

                # Lookahead: periodically pull the slow weights toward the fast weights and reset to them.
                if group['step'] % self.lookahead_merge_time == 0:
                    slow_p = state['slow_momentum']
                    slow_p.lerp_(p, weight=self.lookahead_blending_alpha)
                    p.copy_(slow_p)

        return loss

ROSE

Bases: BaseOptimizer

Range-Of-Slice Equilibration optimizer.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
weight_decay float

Weight decay (L2 penalty).

0.0001
wd_schedule Union[float, bool]

Schedule-Coupled Weight Decay. If False, standard decoupled weight decay is used. If True, lr_ref is the first available among group['max_lr'], group['initial_lr'], and the learning-rate passed at construction time. If a float is provided, it is used directly as lr_ref.

False
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Fix weight decay.

False
centralize bool

Gradient Centralization. Removes shared offsets from gradient slices before the range computation. This can improve generalization and training stability. Biases and other 1D parameters are not centralized.

True
stabilize bool

Coefficient-of-Variation Trust Gating. Computes a trust factor from the coefficient of variation of the per-slice range tensor, and then interpolates between the local range and a smoother global mean denominator. This can smooth noisy gradients.

True
bf16_sr bool

Stochastic Rounding for BFloat16.

True
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/rose.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
class ROSE(BaseOptimizer):
    """Range-Of-Slice Equilibration optimizer.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        weight_decay (float): Weight decay (L2 penalty).
        wd_schedule (Union[float, bool]): Schedule-Coupled Weight Decay. If `False`, standard decoupled weight decay is
            used. If `True`, `lr_ref` is the first available among `group['max_lr']`, `group['initial_lr']`, and the
            learning-rate passed at construction time. If a float is provided, it is used directly as `lr_ref`.
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        centralize (bool): Gradient Centralization. Removes shared offsets from gradient slices before the range
            computation. This can improve generalization and training stability. Biases and other 1D parameters are not
            centralized.
        stabilize (bool): Coefficient-of-Variation Trust Gating. Computes a trust factor from the coefficient of
            variation of the per-slice range tensor, and then interpolates between the local range and a smoother
            global mean denominator. This can smooth noisy gradients.
        bf16_sr (bool): Stochastic Rounding for BFloat16.
        compute_dtype (Optional[torch.dtype]): dtype used for the update math. When `None` and stochastic rounding
            applies to a bf16 parameter, float32 is used.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        wd_schedule: Union[bool, float] = False,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        centralize: bool = True,
        stabilize: bool = True,
        bf16_sr: bool = True,
        compute_dtype: torch.dtype = torch.float64,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        if bf16_sr and compute_dtype not in (torch.float32, torch.float64, None):
            raise ValueError(f'bf16_sr=True has no useful effect when compute_dtype is {compute_dtype}.')

        defaults: Defaults = {
            'lr': lr,
            'weight_decay': weight_decay,
            'wd_schedule': wd_schedule,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'centralize': centralize,
            'stabilize': stabilize,
            'bf16_sr': bf16_sr,
            'compute_dtype': compute_dtype,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ROSE'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the step counter and reject sparse gradients / complex parameters."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

    def _get_wd_lr_ref(self, group: ParamGroup) -> float:
        """Resolve `lr_ref` for schedule-coupled weight decay.

        A float `wd_schedule` is used directly; otherwise the first non-None value among `group['max_lr']` and
        `group['initial_lr']` is used, falling back to the learning rate passed at construction time.
        """
        wd_schedule = group['wd_schedule']
        if isinstance(wd_schedule, float):
            return wd_schedule

        for key in ('max_lr', 'initial_lr'):
            lr_ref = group.get(key)
            if lr_ref is not None:
                return lr_ref

        return self.defaults['lr']

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): Optional closure that re-evaluates the model and returns the loss.

        Returns:
            Loss: The loss returned by `closure`, or None.

        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            lr = group['lr']
            weight_decay, wd_schedule = group['weight_decay'], group['wd_schedule']
            compute_dtype = group['compute_dtype']

            # Schedule-coupled weight decay scales the decay by lr / lr_ref.
            wd_lr = lr / self._get_wd_lr_ref(group) if weight_decay and wd_schedule else lr

            for p in group['params']:
                if p.grad is None:
                    continue

                use_bf16_sr = group['bf16_sr'] and p.dtype is torch.bfloat16
                fp32 = use_bf16_sr and not compute_dtype

                grad = p.grad.to(dtype=torch.float32 if fp32 else compute_dtype)
                param = p.to(dtype=torch.float32 if fp32 else compute_dtype)

                # Decay the working copy `param`, not `p`: when `param` is a copy, `p` is overwritten from `param`
                # at the end of this iteration, which would silently discard any in-place decay applied to `p`.
                self.apply_weight_decay(
                    param,
                    grad,
                    lr=wd_lr,
                    weight_decay=weight_decay,
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if grad.ndim == 0:
                    param.add_(grad.sign(), alpha=-lr)
                elif grad.ndim == 1:
                    g_min, g_max = grad.aminmax()
                    # NOTE(review): `.abs_()` applies only to the max, so a slice whose max is negative gets
                    # |max| - min rather than max - min. Confirm this is intended.
                    de_nom = g_max.abs_().sub_(g_min)

                    de_nom.masked_fill_(de_nom == 0.0, 1.0)
                    param.addcdiv_(grad, de_nom, value=-lr)
                else:
                    # Every axis except the first forms one slice.
                    active_axes = tuple(range(1, grad.ndim))

                    if group['centralize']:
                        # Never mutate p.grad in place; only the private copy may be modified.
                        if grad is not p.grad:
                            grad.sub_(grad.mean(dim=active_axes, keepdim=True))
                        else:
                            grad = grad.sub(grad.mean(dim=active_axes, keepdim=True))

                    raw_scale = (
                        grad.amax(dim=active_axes, keepdim=True).abs_().sub_(grad.amin(dim=active_axes, keepdim=True))
                    )

                    if group['stabilize']:
                        std, mean = torch.std_mean(raw_scale, correction=0)

                        # trust = mean / (std + mean) = 1 / (1 + CV), guarded against a zero mean.
                        trust = mean.div(std.add_(mean).masked_fill_(mean == 0.0, 1.0))

                        de_nom = mean.lerp(raw_scale, trust)
                    else:
                        de_nom = raw_scale

                    de_nom.masked_fill_(de_nom == 0.0, 1.0)
                    param.addcdiv_(grad, de_nom, value=-lr)

                if use_bf16_sr:
                    param = param.to(dtype=torch.float32)

                    copy_stochastic(p, param)
                elif param is not p:
                    p.copy_(param)

        return loss

RotoGrad

Bases: RotateOnly

Implementation of RotoGrad as described in the original paper.

Parameters:

Name Type Description Default
backbone Module

shared module.

required
heads Sequence[Module]

task-specific modules.

required
latent_size int

size of the shared representation, i.e. the size of the backbone's output.

required
burn_in_period int

When back-propagating towards the shared parameters, each task loss is normalized by dividing it by its initial value, L_k(t) / L_k(t_0 = 0). This parameter sets the number of iterations after which the denominator is replaced by the value of the loss at that iteration, that is, t_0 = burn_in_period. This avoids problems caused by losses changing quickly during the first iterations.

20
normalize_losses bool

Whether to use these normalized losses to back-propagate through the task-specific parameters as well.

False
Source code in pytorch_optimizer/optimizer/rotograd.py
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
class RotoGrad(RotateOnly):
    r"""Implementation of RotoGrad as described in the original paper.

    Args:
        backbone (nn.Module): shared module.
        heads (Sequence[nn.Module]): task-specific modules.
        latent_size (int): size of the shared representation, i.e. the size of the backbone's output.
        burn_in_period (int): When back-propagating towards the shared parameters, each task loss is normalized
            dividing by its initial value, \(L_k(t) / L_k(t_0=0)\). This parameter sets a number of iterations
            after which the denominator will be replaced by the value of the loss at that iteration, that is,
            \(t_0 = burn\_in\_period\). This is done to overcome problems with losses quickly changing
            in the first iterations.
        normalize_losses (bool): Whether to use these normalized losses to back-propagate through the task-specific
            parameters as well.

    """

    num_tasks: int
    backbone: nn.Module
    heads: Sequence[nn.Module]
    rep: torch.Tensor

    def __init__(
        self,
        backbone: nn.Module,
        heads: Sequence[nn.Module],
        latent_size: int,
        *args,
        burn_in_period: int = 20,
        normalize_losses: bool = False,
    ):
        super().__init__(backbone, heads, latent_size, burn_in_period, *args, normalize_losses=normalize_losses)

        # Per-task gradient norms recorded at t_0; reset when `counter` reaches `burn_in_period`.
        self.initial_grads = None
        self.counter: int = 0

    def _rep_grad(self):
        """Combine the per-task gradients of the shared representation, weighted by their convergence ratios."""
        super()._rep_grad()

        grad_norms = [torch.linalg.norm(g, keepdim=True).clamp_min(1e-15) for g in self.original_grads]
        if self.initial_grads is None or self.counter == self.burn_in_period:
            self.initial_grads = grad_norms
            # Build the unit ratios on the same device/dtype as the norms; a fixed CPU tensor would fail to combine
            # with CUDA gradient norms below.
            conv_ratios = [torch.ones_like(n) for n in grad_norms]
        else:
            conv_ratios = [x / y for x, y in zip(grad_norms, self.initial_grads)]

        self.counter += 1

        alphas = [x / torch.clamp(sum(conv_ratios), 1e-15) for x in conv_ratios]
        weighted_sum_norms = sum(a * g for a, g in zip(alphas, grad_norms))

        # Rescale every unit-norm task gradient to the shared weighted norm and sum them.
        return sum(g / n * weighted_sum_norms for g, n in zip(self.original_grads, grad_norms))

SafeFP16Optimizer

Bases: Optimizer

Safe FP16 Optimizer.

Parameters:

Name Type Description Default
optimizer Optimizer

Optimizer instance.

required
aggregate_g_norms bool

Aggregate gradient norms.

False
min_loss_scale float

Minimum loss scale.

2 ** -5
Source code in pytorch_optimizer/optimizer/fp16.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
class SafeFP16Optimizer(Optimizer):  # pragma: no cover
    """Safe FP16 Optimizer.

    Wraps an optimizer so that it updates fp32 master copies of the model parameters, using dynamic loss scaling to
    avoid fp16 gradient underflow.

    Args:
        optimizer (Optimizer): Optimizer instance. Must have exactly one parameter group.
        aggregate_g_norms (bool): Aggregate gradient norms.
        min_loss_scale (float): Minimum loss scale.

    Raises:
        NotImplementedError: If the optimizer does not have exactly one parameter group.

    """

    def __init__(
        self,
        optimizer: Optimizer,
        aggregate_g_norms: bool = False,
        min_loss_scale: float = 2 ** -5,
    ) -> None:  # fmt: skip
        self.optimizer = optimizer
        self.aggregate_g_norms = aggregate_g_norms
        self.min_loss_scale = min_loss_scale

        self.fp16_params = self.get_parameters(optimizer)
        self.fp32_params = self.build_fp32_params(self.fp16_params, flatten=False)

        # we want the optimizer to be tracking the fp32 parameters
        if len(optimizer.param_groups) != 1:
            # future implementers: this should hopefully be a matter of just iterating through the param groups and
            # keeping track of the pointer through the fp32_params
            raise NotImplementedError('Need to implement the parameter group transfer.')

        optimizer.param_groups[0]['params'] = self.fp32_params

        self.scaler: DynamicLossScaler = DynamicLossScaler(2.0 ** 15)  # fmt: skip
        self.needs_sync: bool = True

    @classmethod
    def get_parameters(cls, optimizer: Optimizer) -> List:
        """Return every parameter from all of the optimizer's param groups as one flat list."""
        params: List = []
        for group in optimizer.param_groups:
            params += list(group['params'])
        return params

    @classmethod
    def build_fp32_params(cls, parameters: ParamsT, flatten: bool = True) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Create fp32 master copies of `parameters`, each with a zero-initialized gradient.

        Args:
            parameters (ParamsT): The (typically fp16) parameters to copy.
            flatten (bool): If True, pack all values into a single flat fp32 `nn.Parameter`; otherwise return one
                fp32 `nn.Parameter` per input parameter.

        """
        parameters = cast(List[torch.Tensor], parameters)

        if flatten:
            total_param_size: int = sum(p.numel() for p in parameters)
            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)  # type: ignore

            offset: int = 0
            for p in parameters:
                p_num_el = p.numel()
                fp32_params[offset:offset + p_num_el].copy_(p.view(-1))  # fmt: skip
                offset += p_num_el

            fp32_params = nn.Parameter(fp32_params)  # type: ignore
            # Zero-initialize so slots that are never synced (e.g. frozen params) do not hold garbage memory.
            fp32_params.grad = fp32_params.new_zeros(total_param_size)

            return fp32_params

        fp32_params: List[torch.Tensor] = []
        for p in parameters:
            p32 = nn.Parameter(p.float())
            p32.grad = torch.zeros_like(p32)
            fp32_params.append(p32)

        return fp32_params

    def state_dict(self) -> Dict:
        """Return the optimizer state dict."""
        state_dict = self.optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scaler'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict: Dict):
        """Load an optimizer state dict.

        In general, prefer using the existing optimizer instance's configuration (e.g., learning rate)
        over the values found in the state_dict. This approach allows resuming training from a checkpoint
        while applying new optimizer arguments.

        Args:
            state_dict (dict): The state dictionary to load into the optimizer.

        """
        if 'loss_scaler' in state_dict and self.scaler is not None and isinstance(state_dict['loss_scaler'], float):
            self.scaler.loss_scale = state_dict['loss_scaler']
        self.optimizer.load_state_dict(state_dict)

    def backward(self, loss, update_main_grads: bool = False):
        """Compute the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to `fairseq.optim.FairseqOptimizer.backward`, this function
        additionally dynamically scales the loss to avoid gradient underflow.

        Args:
            loss (torch.Tensor): The loss tensor to backpropagate.
            update_main_grads (bool): Whether to update the main gradient during backpropagation.

        """
        if self.scaler is not None:
            loss = loss * self.scaler.loss_scale

        loss.backward()

        self.needs_sync = True
        if update_main_grads:
            self.update_main_grads()

    def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0) -> None:
        """Copy the fp16 gradients into the fp32 master gradients, un-scaling them by the loss scale.

        Args:
            multiply_grads (float): Extra factor applied to the gradients while copying.

        """
        if not self.needs_sync:
            return

        if self.scaler is not None:
            multiply_grads /= self.scaler.loss_scale

        for p16, p32 in zip(self.fp16_params, self.fp32_params):
            if not p16.requires_grad:
                continue

            if p16.grad is None:
                p32.grad = torch.zeros_like(p16, dtype=torch.float)
                continue

            if p32.grad is None:
                # The master gradient may have been released (e.g. by the wrapped optimizer's zero_grad).
                p32.grad = torch.zeros_like(p32)

            p32.grad.copy_(p16.grad)
            p32.grad.mul_(multiply_grads)

        self.needs_sync = False

    def multiply_grads(self, c: float) -> None:
        """Multiply grads by a constant c."""
        if self.needs_sync:
            self.sync_fp16_grads_to_fp32(c)
            return

        for p32 in self.fp32_params:
            if p32.grad is not None:
                p32.grad.mul_(c)

    def update_main_grads(self) -> None:
        """Sync the fp16 gradients into the fp32 master gradients (no-op if already synced)."""
        self.sync_fp16_grads_to_fp32()

    def clip_main_grads(self, max_norm: float):
        """Clip the gradient norm and update the dynamic loss scaler.

        Args:
            max_norm (float): Maximum gradient norm.

        Returns:
            The gradient norm computed by `clip_grad_norm`.

        Raises:
            FloatingPointError: If an overflow drives the loss scale to or below `min_loss_scale`.

        """
        self.sync_fp16_grads_to_fp32()

        grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_g_norms)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            overflow: bool = has_overflow(grad_norm)
            prev_scale: float = self.scaler.loss_scale

            self.scaler.update_scale(overflow)

            if overflow:
                self.zero_grad()
                if self.scaler.loss_scale <= self.min_loss_scale:
                    # Use FloatingPointError as an uncommon error
                    # that parent functions can safely catch to stop training.
                    self.scaler.loss_scale = prev_scale

                    raise FloatingPointError(
                        f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
                        'Try lowering the learning rate, using gradient clipping or increasing the batch size.\n'
                        f'Overflow: setting loss scale to {self.scaler.loss_scale}'
                    )

        return grad_norm

    def step(self, closure: Closure = None):
        """Perform a single optimization step."""
        self.sync_fp16_grads_to_fp32()
        self.optimizer.step(closure)

        # Copy the updated fp32 master weights back into the trainable fp16 parameters.
        for p16, p32 in zip(self.fp16_params, self.fp32_params):
            if not p16.requires_grad:
                continue
            p16.data.copy_(p32)

    def zero_grad(self) -> None:
        """Clear the gradients of all optimized parameters."""
        for p16 in self.fp16_params:
            p16.grad = None
        for p32 in self.fp32_params:
            if p32.grad is not None:
                p32.grad.zero_()
        self.needs_sync = False

    def get_lr(self) -> float:
        """Get learning rate."""
        return self.optimizer.get_lr()

    def set_lr(self, lr: float):
        """Set learning rate."""
        self.optimizer.set_lr(lr)

    @property
    def loss_scale(self) -> float:
        """Convenience function which TorchAgent calls to get current scale value."""
        return self.scaler.loss_scale

loss_scale property

Convenience function which TorchAgent calls to get current scale value.

backward(loss, update_main_grads=False)

Compute the sum of gradients of the given tensor w.r.t. graph leaves.

Compared to `fairseq.optim.FairseqOptimizer.backward`, this function additionally scales the loss dynamically to avoid gradient underflow.

Parameters:

Name Type Description Default
loss Tensor

The loss tensor to backpropagate.

required
update_main_grads bool

Whether to update the main gradient during backpropagation.

False
Source code in pytorch_optimizer/optimizer/fp16.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def backward(self, loss, update_main_grads: bool = False):
    """Back-propagate ``loss`` after applying the dynamic loss scale.

    Unlike `fairseq.optim.FairseqOptimizer.backward`, the loss is first multiplied by the current loss scale (when
    a scaler is configured) so that small fp16 gradients do not underflow.

    Args:
        loss (torch.Tensor): The loss tensor to backpropagate.
        update_main_grads (bool): If True, sync the gradients into the fp32 master copies right after backward.

    """
    scaled_loss = loss if self.scaler is None else loss * self.scaler.loss_scale
    scaled_loss.backward()

    # Fresh fp16 gradients now exist and have not been copied to the fp32 masters yet.
    self.needs_sync = True

    if not update_main_grads:
        return

    self.update_main_grads()

clip_main_grads(max_norm)

Clip the gradient norm and update the dynamic loss scaler.

Source code in pytorch_optimizer/optimizer/fp16.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def clip_main_grads(self, max_norm: float):
    """Clip the fp32 gradient norm and update the dynamic loss scaler.

    Args:
        max_norm (float): Maximum gradient norm.

    Returns:
        The gradient norm computed by `clip_grad_norm`.

    Raises:
        FloatingPointError: If an overflow drives the loss scale to or below `min_loss_scale`.

    """
    self.sync_fp16_grads_to_fp32()

    grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_g_norms)

    scaler = self.scaler
    if scaler is None:
        return grad_norm

    # Detect overflow and let the scaler adjust its loss scale accordingly.
    overflow: bool = has_overflow(grad_norm)
    previous_scale: float = scaler.loss_scale
    scaler.update_scale(overflow)

    if not overflow:
        return grad_norm

    # The gradients of an overflowing step are unusable; drop them.
    self.zero_grad()

    if scaler.loss_scale <= self.min_loss_scale:
        # Restore the last usable scale, then signal with an uncommon exception type that callers can catch to
        # stop training.
        scaler.loss_scale = previous_scale

        raise FloatingPointError(
            f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
            'Try lowering the learning rate, using gradient clipping or increasing the batch size.\n'
            f'Overflow: setting loss scale to {self.scaler.loss_scale}'
        )

    return grad_norm

get_lr()

Get learning rate.

Source code in pytorch_optimizer/optimizer/fp16.py
282
283
284
def get_lr(self) -> float:
    """Return the learning rate reported by the wrapped optimizer."""
    wrapped = self.optimizer
    return wrapped.get_lr()

load_state_dict(state_dict)

Load an optimizer state dict.

In general, prefer using the existing optimizer instance's configuration (e.g., learning rate) over the values found in the state_dict. This approach allows resuming training from a checkpoint while applying new optimizer arguments.

Parameters:

Name Type Description Default
state_dict dict

The state dictionary to load into the optimizer.

required
Source code in pytorch_optimizer/optimizer/fp16.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def load_state_dict(self, state_dict: Dict):
    """Restore the loss scale (if saved) and then the wrapped optimizer's state.

    The configuration of the existing optimizer instance (e.g. learning rate) generally takes precedence over the
    values stored in ``state_dict``, so training can resume from a checkpoint with new optimizer arguments.

    Args:
        state_dict (dict): The state dictionary to load into the optimizer.

    """
    if self.scaler is not None:
        saved_scale = state_dict.get('loss_scaler')
        # Only a float loss scale is restored; anything else is ignored.
        if isinstance(saved_scale, float):
            self.scaler.loss_scale = saved_scale

    self.optimizer.load_state_dict(state_dict)

multiply_grads(c)

Multiply grads by a constant c.

Source code in pytorch_optimizer/optimizer/fp16.py
224
225
226
227
228
229
230
231
def multiply_grads(self, c: float) -> None:
    """Scale every fp32 master gradient by the constant ``c``.

    If the fp16 gradients still need to be copied over (``needs_sync``), the scaling is folded into
    that copy instead of being applied to the fp32 gradients directly.

    Args:
        c (float): multiplier applied to the gradients.
    """
    if not self.needs_sync:
        for master in self.fp32_params:
            master.grad.mul_(c)
        return

    self.sync_fp16_grads_to_fp32(c)

set_lr(lr)

Set learning rate.

Source code in pytorch_optimizer/optimizer/fp16.py
286
287
288
def set_lr(self, lr: float):
    """Forward a new learning rate to the wrapped optimizer."""
    inner_optimizer = self.optimizer
    inner_optimizer.set_lr(lr)

state_dict()

Return the optimizer state dict.

Source code in pytorch_optimizer/optimizer/fp16.py
164
165
166
167
168
169
def state_dict(self) -> Dict:
    """Return the wrapped optimizer's state dict, adding the loss scale when a scaler is attached."""
    snapshot = self.optimizer.state_dict()
    scaler = self.scaler
    if scaler is not None:
        snapshot['loss_scaler'] = scaler.loss_scale
    return snapshot

step(closure=None)

Perform a single optimization step.

Source code in pytorch_optimizer/optimizer/fp16.py
264
265
266
267
268
269
270
271
272
def step(self, closure: Closure = None):
    """Perform one optimization step on the fp32 master weights and mirror them back to fp16.

    Args:
        closure (Closure): optional closure forwarded to the wrapped optimizer's ``step``.
    """
    self.sync_fp16_grads_to_fp32()
    self.optimizer.step(closure)

    # copy the updated fp32 master weights into the trainable fp16 parameters only
    trainable_pairs = (
        (half, full) for half, full in zip(self.fp16_params, self.fp32_params) if half.requires_grad
    )
    for half, full in trainable_pairs:
        half.data.copy_(full)

sync_fp16_grads_to_fp32(multiply_grads=1.0)

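Copy the fp16 gradients into the fp32 master gradients, multiplying them by `multiply_grads` (and dividing by the loss scale when a loss scaler is set).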

Source code in pytorch_optimizer/optimizer/fp16.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0) -> None:
    """Copy the fp16 gradients into the fp32 master gradients.

    Does nothing unless ``self.needs_sync`` is set. Each copied gradient is multiplied by
    ``multiply_grads``, divided by the loss scale when a scaler is attached. Trainable fp16 parameters
    without a gradient get a zero fp32 gradient. ``needs_sync`` is cleared afterwards.

    Args:
        multiply_grads (float): factor applied to every copied gradient.
    """
    if not self.needs_sync:
        return

    factor = multiply_grads
    if self.scaler is not None:
        # undo the dynamic loss scaling applied to the fp16 loss
        factor /= self.scaler.loss_scale

    for half, full in zip(self.fp16_params, self.fp32_params):
        if not half.requires_grad:
            continue

        if half.grad is None:
            full.grad = torch.zeros_like(half, dtype=torch.float)
        else:
            full.grad.copy_(half.grad)
            full.grad.mul_(factor)

    self.needs_sync = False

zero_grad()

Clear the gradients of all optimized parameters.

Source code in pytorch_optimizer/optimizer/fp16.py
274
275
276
277
278
279
280
def zero_grad(self) -> None:
    """Clear gradients: drop the fp16 ones and zero the fp32 master ones in place."""
    for half_param in self.fp16_params:
        half_param.grad = None

    master_grads = [master.grad for master in self.fp32_params]
    for master_grad in master_grads:
        master_grad.zero_()

    self.needs_sync = False

SAM

Bases: BaseOptimizer

Sharpness-Aware Minimization for Efficiently Improving Generalization.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer Optimizer

base optimizer.

required
rho float

size of the neighborhood for computing the max loss.

0.05
adaptive bool

element-wise Adaptive SAM.

False
use_gc bool

perform gradient centralization, GCSAM variant.

False
perturb_eps float

eps for perturbation.

1e-12
kwargs Dict

parameters for optimizer.

{}
Example
model = YourModel()
base_optimizer = Ranger21
optimizer = SAM(model.parameters(), base_optimizer)
for input, output in data:
    # first forward-backward pass
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    # make sure to do a full forward pass
    loss_function(output, model(input)).backward()
    optimizer.second_step(zero_grad=True)

Alternative example with a single closure-based step function:

model = YourModel()
base_optimizer = Ranger21
optimizer = SAM(model.parameters(), base_optimizer)

def closure():
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

for input, output in data:
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.step(closure)
    optimizer.zero_grad()
Source code in pytorch_optimizer/optimizer/sam.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class SAM(BaseOptimizer):
    """Sharpness-Aware Minimization for Efficiently Improving Generalization.

    SAM first moves the weights to a perturbed point ``w + e(w)`` inside a ``rho``-sized neighborhood, then
    restores ``w`` and lets ``base_optimizer`` update it with the gradient computed at the perturbed point.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        base_optimizer (Optimizer): base optimizer class, instantiated on SAM's parameter groups.
        rho (float): size of the neighborhood for computing the max loss.
        adaptive (bool): element-wise Adaptive SAM (perturbation weighted by ``p ** 2``).
        use_gc (bool): perform gradient centralization, GCSAM variant.
        perturb_eps (float): eps added to the global gradient norm for the perturbation.
        kwargs (Dict): parameters forwarded to the base optimizer (also stored in the group defaults).

    Example:
        ```python
        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)
        for input, output in data:
            # first forward-backward pass
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            # make sure to do a full forward pass
            loss_function(output, model(input)).backward()
            optimizer.second_step(zero_grad=True)

        Alternative example with a single closure-based step function:

        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)

        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss

        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step(closure)
            optimizer.zero_grad()
        ```

    """

    def __init__(
        self,
        params: ParamsT,
        base_optimizer: OptimizerType,
        rho: float = 0.05,
        adaptive: bool = False,
        use_gc: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        for value, label in ((rho, 'rho'), (perturb_eps, 'perturb_eps')):
            self.validate_non_negative(value, label)

        self.use_gc = use_gc
        self.perturb_eps = perturb_eps

        defaults: Defaults = {'rho': rho, 'adaptive': adaptive, **kwargs}
        super().__init__(params, defaults)

        # the base optimizer works on the very same parameter groups as SAM
        self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'SAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        """Climb to the perturbed point ``w + e(w)``, saving ``w`` in ``state['old_p']``."""
        first_param = self.param_groups[0]['params'][0]
        grad_norm = get_global_gradient_norm(self.param_groups, first_param.device).add_(self.perturb_eps)

        for group in self.param_groups:
            scale = group['rho'] / grad_norm

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                self.state[p]['old_p'] = p.clone()

                weighting = torch.pow(p, 2) if group['adaptive'] else 1.0
                perturbation = weighting * grad * scale.to(p)
                p.add_(perturbation)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """Return to ``w`` and apply the base optimizer's sharpness-aware update."""
        params_with_grad = (p for group in self.param_groups for p in group['params'] if p.grad is not None)
        for p in params_with_grad:
            p.data = self.state[p]['old_p']

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: Closure = None):
        """Run both SAM phases; ``closure`` must perform a full forward-backward pass."""
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step(zero_grad=True)

        # gradients at the perturbed point are produced by the closure
        with torch.enable_grad():
            closure()

        self.second_step()

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        # loading replaces ``self.param_groups``; re-share them with the base optimizer
        self.base_optimizer.param_groups = self.param_groups

ScalableShampoo

Bases: BaseOptimizer

Scalable Preconditioned Stochastic Tensor Optimization.

This version of the Scalable Shampoo Optimizer targets single GPU environments, computing pre-conditioners synchronously on GPU (which takes most of the optimization time). It is faster than previous Shampoo implementations by using coupled Newton iteration for matrix inverse powers instead of slow SVD calculations.

Features include: 1. Various plug-ins (e.g., gradient grafting, preconditioning types), 2. Additional features beyond official PyTorch code, 3. Readable and well-organized implementation.

Reference: https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py

Parameters:

Name Type Description Default
params ParamsT

iterable or dicts defining parameter groups.

required
lr float

learning rate.

0.001
betas tuple

beta1 for momentum; beta2 for the statistics accumulated by the grafting method and the preconditioner.

(0.9, 0.999)
moving_average_for_momentum bool

whether to perform moving average for momentum (beta1).

False
weight_decay float

weight decay (L2 penalty).

0.0
decoupled_weight_decay bool

use decoupled weight decay.

False
decoupled_learning_rate bool

use decoupled learning rate, otherwise coupled with preconditioned gradient.

True
inverse_exponent_override int

fixed exponent for preconditioner if > 0.

0
start_preconditioning_step int

step to start preconditioning.

25
preconditioning_compute_steps int

frequency of preconditioner computation.

1000
statistics_compute_steps int

frequency of statistics computation.

1
block_size int

block size for large layers; 1 means AdaGrad (inefficient).

512
skip_preconditioning_rank_lt int

skip preconditioning for layers with rank below this.

1
no_preconditioning_for_layers_with_dim_gt int

avoid preconditioning large layers.

8192
shape_interpretation bool

automatic shape interpretation for tensor dims.

True
graft_type int

type of grafting (SGD, AdaGrad, RMSProp, etc.).

SGD
pre_conditioner_type int

type of preconditioner.

ALL
nesterov bool

enable Nesterov momentum.

True
diagonal_eps float

epsilon for numerical stability in diagonal.

1e-10
matrix_eps float

epsilon for numerical stability in matrix.

1e-06
use_svd bool

whether to use SVD for matrix inverse powers (alternative is Schur-Newton).

False
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/shampoo.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
class ScalableShampoo(BaseOptimizer):
    """Scalable Preconditioned Stochastic Tensor Optimization.

    This version of the Scalable Shampoo Optimizer targets single GPU environments,
    computing pre-conditioners synchronously on GPU (which takes most of the optimization time).
    It is faster than previous Shampoo implementations by using coupled Newton iteration
    for matrix inverse powers instead of slow SVD calculations.

    Features include:
    1. Various plug-ins (e.g., gradient grafting, preconditioning types),
    2. Additional features beyond official PyTorch code,
    3. Readable and well-organized implementation.

    Reference:
    https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py

    Args:
        params (ParamsT): iterable or dicts defining parameter groups.
        lr (float): learning rate.
        betas (tuple): beta1 for momentum; beta2 for the statistics accumulated by the grafting method and
            the preconditioner.
        moving_average_for_momentum (bool): whether to perform moving average for momentum (beta1).
        weight_decay (float): weight decay (L2 penalty).
        decoupled_weight_decay (bool): use decoupled weight decay (applied to the parameters once per step).
        decoupled_learning_rate (bool): use decoupled learning rate, otherwise coupled with preconditioned gradient.
        inverse_exponent_override (int): fixed exponent for preconditioner if > 0.
        start_preconditioning_step (int): step to start preconditioning.
        preconditioning_compute_steps (int): frequency of preconditioner computation.
        statistics_compute_steps (int): frequency of statistics computation.
        block_size (int): block size for large layers; 1 means AdaGrad (inefficient).
        skip_preconditioning_rank_lt (int): skip preconditioning for layers with rank below this.
        no_preconditioning_for_layers_with_dim_gt (int): avoid preconditioning large layers.
        shape_interpretation (bool): automatic shape interpretation for tensor dims.
        graft_type (int): type of grafting (SGD, AdaGrad, RMSProp, etc.).
        pre_conditioner_type (int): type of preconditioner.
        nesterov (bool): enable Nesterov momentum.
        diagonal_eps (float): epsilon for numerical stability in diagonal.
        matrix_eps (float): epsilon for numerical stability in matrix.
        use_svd (bool): whether to use SVD for matrix inverse powers (alternative is Schur-Newton).
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        moving_average_for_momentum: bool = False,
        weight_decay: float = 0.0,
        decoupled_weight_decay: bool = False,
        decoupled_learning_rate: bool = True,
        inverse_exponent_override: int = 0,
        start_preconditioning_step: int = 25,
        preconditioning_compute_steps: int = 1000,
        statistics_compute_steps: int = 1,
        block_size: int = 512,
        skip_preconditioning_rank_lt: int = 1,
        no_preconditioning_for_layers_with_dim_gt: int = 8192,
        shape_interpretation: bool = True,
        graft_type: int = LayerWiseGrafting.SGD,
        pre_conditioner_type: int = PreConditionerType.ALL,
        nesterov: bool = True,
        diagonal_eps: float = 1e-10,
        matrix_eps: float = 1e-6,
        use_svd: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_step(start_preconditioning_step, 'start_preconditioning_step')
        self.validate_step(preconditioning_compute_steps, 'preconditioning_compute_steps')
        self.validate_step(statistics_compute_steps, 'statistics_compute_steps')
        self.validate_non_negative(diagonal_eps, 'diagonal_eps')
        self.validate_non_negative(matrix_eps, 'matrix_eps')

        self.inverse_exponent_override = inverse_exponent_override
        self.start_preconditioning_step = start_preconditioning_step
        self.preconditioning_compute_steps = preconditioning_compute_steps
        self.statistics_compute_steps = statistics_compute_steps
        self.block_size = block_size
        self.skip_preconditioning_rank_lt = skip_preconditioning_rank_lt
        self.no_preconditioning_for_layers_with_dim_gt = no_preconditioning_for_layers_with_dim_gt
        self.shape_interpretation = shape_interpretation
        self.graft_type = graft_type
        self.pre_conditioner_type = pre_conditioner_type
        self.diagonal_eps = diagonal_eps
        self.matrix_eps = matrix_eps
        self.use_svd = use_svd
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'decoupled_weight_decay': decoupled_weight_decay,
            'decoupled_learning_rate': decoupled_learning_rate,
            'moving_average_for_momentum': moving_average_for_momentum,
            'nesterov': nesterov,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ScalableShampoo'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        _, beta2 = group['betas']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['momentum'] = torch.zeros_like(grad)
                state['pre_conditioner'] = PreConditioner(
                    p,
                    beta2,
                    self.inverse_exponent_override,
                    self.block_size,
                    self.skip_preconditioning_rank_lt,
                    self.no_preconditioning_for_layers_with_dim_gt,
                    self.shape_interpretation,
                    self.pre_conditioner_type,
                    self.matrix_eps,
                    self.use_svd,
                )
                state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)

    def is_precondition_step(self, step: int) -> bool:
        return step >= self.start_preconditioning_step

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            is_precondition_step: bool = self.is_precondition_step(group['step'])
            # with a coupled learning rate, lr is folded into the gradient fed to the grafting method
            pre_conditioner_multiplier: float = 1.0 if group['decoupled_learning_rate'] else group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                pre_conditioner, graft = state['pre_conditioner'], state['graft']

                graft.add_statistics(grad, beta2)
                if group['step'] % self.statistics_compute_steps == 0:
                    pre_conditioner.add_statistics(grad)
                if group['step'] % self.preconditioning_compute_steps == 0:
                    pre_conditioner.compute_pre_conditioners()

                graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                shampoo_grad: torch.Tensor = (
                    pre_conditioner.preconditioned_grad(grad) if is_precondition_step else grad
                )

                if self.graft_type != LayerWiseGrafting.NONE:
                    # rescale the shampoo direction to the norm of the grafted direction
                    graft_norm = torch.linalg.norm(graft_grad)
                    shampoo_norm = torch.linalg.norm(shampoo_grad)

                    shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

                if group['decoupled_weight_decay']:
                    # Decoupled decay acts on ``p`` itself, so it must run once per step. Calling it for both
                    # the graft and the shampoo gradients would shrink ``p`` twice.
                    # NOTE(review): relies on apply_weight_decay ignoring ``grad`` when weight_decouple=True.
                    self.apply_weight_decay(
                        p,
                        grad=shampoo_grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=True,
                        fixed_decay=False,
                    )
                else:
                    # coupled (L2) decay is added to both candidate update directions
                    for g in (graft_grad, shampoo_grad):
                        self.apply_weight_decay(
                            p,
                            grad=g,
                            lr=group['lr'],
                            weight_decay=group['weight_decay'],
                            weight_decouple=False,
                            fixed_decay=False,
                        )

                state['momentum'].mul_(beta1).add_(shampoo_grad)
                graft_momentum = graft.update_momentum(grad, beta1)

                # before preconditioning starts, the grafted momentum drives the update
                momentum_update = state['momentum'] if is_precondition_step else graft_momentum

                if group['nesterov']:
                    w: float = (1.0 - beta1) if group['moving_average_for_momentum'] else 1.0

                    wd_update = shampoo_grad if is_precondition_step else graft_grad
                    wd_update.mul_(w)

                    momentum_update.mul_(beta1).add_(wd_update)

                p.add_(momentum_update, alpha=-group['lr'])

        return loss

ScheduleFreeAdamW

Bases: BaseOptimizer

Schedule-Free AdamW.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.0025
betas Betas

beta1 sets the interpolation y = beta1 * x + (1 - beta1) * z used for training; beta2 is the decay rate of the running average of squared gradients.

(0.9, 0.999)
weight_decay float

weight decay (L2 penalty).

0.0
r float

use polynomial weighting in the average with power r.

0.0
weight_lr_power float

during warmup, weights in the average equal to lr raised to this power; 0 disables weighting.

2.0
warmup_steps int

enables a linear learning rate warmup.

0
decoupling_c int

coefficient proposed in Refined Schedule-Free AdamW (values around 200 are suggested); 0, the default, disables it.

0
ams_bound bool

whether to use the AMSBound variant.

False
eps float

term added to denominator for numerical stability.

1e-08
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
class ScheduleFreeAdamW(BaseOptimizer):
    """Schedule-Free AdamW.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        betas (Betas): beta1 sets the interpolation ``y = beta1 * x + (1 - beta1) * z`` used for training;
            beta2 is the decay rate of the running average of squared gradients.
        weight_decay (float): weight decay (L2 penalty).
        r (float): use polynomial weighting in the average with power r.
        weight_lr_power (float): during warmup, weights in the average equal to lr raised to this power;
            0 disables weighting.
        warmup_steps (int): enables a linear learning rate warmup.
        decoupling_c (int): coefficient proposed in Refined Schedule-Free AdamW (values around 200 are
            suggested); 0, the default, disables it.
        ams_bound (bool): whether to use the AMSBound variant.
        eps (float): term added to denominator for numerical stability.
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 2.5e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        warmup_steps: int = 0,
        decoupling_c: int = 0,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(decoupling_c, 'decoupling_c')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'r': r,
            'weight_lr_power': weight_lr_power,
            'warmup_steps': warmup_steps,
            'decoupling_c': decoupling_c,
            'ams_bound': ams_bound,
            'eps': eps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'use_palm': kwargs.get('use_palm', False),
        }

        super().__init__(params, defaults)

        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'ScheduleFreeAdamW'

    def eval(self):
        """Move parameters from the training point ``y`` to the evaluation point ``x`` (no-op if already there)."""
        for group in self.param_groups:
            beta1, _ = group['betas']
            if group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        # inverts y = beta1 * x + (1 - beta1) * z
                        p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1)
                group['train_mode'] = False

    def train(self):
        """Move parameters from the evaluation point ``x`` back to the training point ``y``."""
        for group in self.param_groups:
            beta1, _ = group['betas']
            if not group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - beta1)
                group['train_mode'] = True

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = p.clone()
                state['exp_avg_sq'] = torch.zeros_like(p)

            # step() reads this buffer when ams_bound is enabled; allocate it (also for states restored
            # from a checkpoint taken with ams_bound disabled).
            if group['ams_bound'] and 'max_exp_avg_sq' not in state:
                state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            warmup_steps: int = group['warmup_steps']
            schedule: float = group['step'] / warmup_steps if group['step'] < warmup_steps else 1.0

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            lr: float = group['lr'] * schedule
            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            # weight of the newest iterate in the running average
            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            if group['decoupling_c'] > 0:
                checkpoint = min(1.0, checkpoint * (1.0 - beta1) * group['decoupling_c'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z, exp_avg_sq = state['z'], state['exp_avg_sq']
                max_exp_avg_sq = state.get('max_exp_avg_sq', None)

                p, grad, z, exp_avg_sq = self.view_as_real(p, grad, z, exp_avg_sq)
                if max_exp_avg_sq is not None and torch.is_complex(max_exp_avg_sq):
                    # keep the AMSBound buffer in the same real view as ``exp_avg_sq``
                    max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq.div(bias_correction2),
                    max_exp_avg_sq=max_exp_avg_sq,
                    eps=group['eps'],
                )

                grad.div_(de_nom)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                # update y in place without materializing x
                p.lerp_(z, weight=checkpoint)
                p.add_(grad, alpha=lr * (beta1 * (1.0 - checkpoint) - 1))

                z.sub_(grad, alpha=lr)

        return loss

ScheduleFreeRAdam

Bases: BaseOptimizer

Schedule-Free RAdam.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.0025
betas Betas

coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

weight decay (L2 penalty).

0.0
r float

use polynomial weighting in the average with power r.

0.0
weight_lr_power float

during warmup, weights in the average equal to lr raised to this power; 0 disables weighting.

2.0
silent_sgd_phase bool

if True, disables updates in the early SGD phase, only updates momentum to stabilize training.

True
eps float

term added to denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
class ScheduleFreeRAdam(BaseOptimizer):
    """Schedule-Free RAdam.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        betas (Betas): ``beta1`` is the interpolation (momentum) coefficient between the averaged iterate and the
            base iterate ``z``; ``beta2`` is the coefficient of the running average of squared gradients.
        weight_decay (float): weight decay (L2 penalty).
        r (float): use polynomial weighting in the average with power r.
        weight_lr_power (float): during warmup, weights in the average are equal to lr raised to this power;
            0 disables weighting.
        silent_sgd_phase (bool): if True, parameter updates are zeroed during the early steps where the RAdam
            rectification term is not yet available; the squared-gradient average is still accumulated.
        eps (float): term added to denominator to improve numerical stability.
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 2.5e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        silent_sgd_phase: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'silent_sgd_phase': silent_sgd_phase,
            'r': r,
            'weight_lr_power': weight_lr_power,
            'eps': eps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'use_palm': kwargs.get('use_palm', False),
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ScheduleFreeRAdam'

    def eval(self):
        r"""Switch parameters from the training (interpolated) point to the evaluation point.

        Extrapolates each parameter away from its ``z`` buffer with weight ``1 - 1 / beta1``. Groups that are
        already in eval mode are left untouched.
        """
        for group in self.param_groups:
            beta1, _ = group['betas']
            if group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1)
                group['train_mode'] = False

    def train(self):
        r"""Switch parameters back to the training (interpolated) point.

        Inverse of ``eval``: interpolates each parameter toward its ``z`` buffer with weight ``1 - beta1``.
        Groups that are already in train mode are left untouched.
        """
        for group in self.param_groups:
            beta1, _ = group['betas']
            if not group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - beta1)
                group['train_mode'] = True

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialize the step counter and the per-parameter ``z`` / ``exp_avg_sq`` buffers.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = p.clone()
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure`` or None.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            # sqrt(v_hat) = sqrt(v / (1 - beta2^t)) = sqrt(v) / sqrt(1 - beta2^t), so the denominator must be
            # divided by the *square root* of the second-moment bias correction.
            bias_correction2_sq: float = self.debias(beta2, group['step']) ** 0.5

            lr, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=4,
                degenerated_to_sgd=False,
            )
            if lr < 0.0:
                # rectification is not available yet: either freeze (silent) or take plain SGD steps with lr=1.
                lr = float(not group['silent_sgd_phase'])

            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            adaptive_y_lr: float = lr * (beta1 * (1.0 - checkpoint) - 1.0)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z, exp_avg_sq = state['z'], state['exp_avg_sq']

                p, grad, z, exp_avg_sq = self.view_as_real(p, grad, z, exp_avg_sq)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if n_sma > 4.0:
                    # NOTE: normalizes the gradient tensor in place.
                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
                    grad.div_(de_nom)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                p.lerp_(z, weight=checkpoint)
                p.add_(grad, alpha=adaptive_y_lr)

                z.sub_(grad, alpha=lr)

        return loss

ScheduleFreeSGD

Bases: BaseOptimizer

Schedule-Free SGD.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

1.0
momentum float

momentum factor, must be between 0 and 1 exclusive.

0.9
weight_decay float

weight decay (L2 penalty).

0.0
r float

use polynomial weighting in the average with power r.

0.0
weight_lr_power float

during warmup, weights in the average are equal to lr raised to this power; 0 disables weighting.

2.0
warmup_steps int

enables a linear learning rate warmup.

0
eps float

term added to denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class ScheduleFreeSGD(BaseOptimizer):
    """Schedule-Free SGD.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): momentum factor, must be strictly between 0 and 1 (``eval`` divides by it).
        weight_decay (float): weight decay (L2 penalty).
        r (float): use polynomial weighting in the average with power r.
        weight_lr_power (float): during warmup, weights in the average are equal to lr raised to this power;
            0 disables weighting.
        warmup_steps (int): enables a linear learning rate warmup.
        eps (float): term added to denominator to improve numerical stability.
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1.0,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        warmup_steps: int = 0,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        # exclusive range as documented: momentum == 0 would make `eval` divide by zero.
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='()')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'r': r,
            'weight_lr_power': weight_lr_power,
            'warmup_steps': warmup_steps,
            'eps': eps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
        }

        super().__init__(params, defaults)

        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'ScheduleFreeSGD'

    def eval(self):
        r"""Switch parameters from the training (interpolated) point to the evaluation point.

        Extrapolates each parameter away from its ``z`` buffer with weight ``1 - 1 / momentum``.
        """
        for group in self.param_groups:
            momentum = group['momentum']
            if group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / momentum)
                group['train_mode'] = False

    def train(self):
        r"""Switch parameters back to the training point (inverse of ``eval``)."""
        for group in self.param_groups:
            momentum = group['momentum']
            if not group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - momentum)
                group['train_mode'] = True

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialize the step counter and the per-parameter ``z`` buffer.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = p.clone()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure`` or None.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            # linear warmup from 1 / warmup_steps up to 1.0
            warmup_steps: int = group['warmup_steps']
            schedule: float = group['step'] / warmup_steps if group['step'] < warmup_steps else 1.0

            momentum = group['momentum']

            lr: float = group['lr'] * schedule
            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z = state['z']

                p, grad, z = self.view_as_real(p, grad, z)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                p.lerp_(z, weight=checkpoint)
                p.add_(grad, alpha=lr * (momentum * (1.0 - checkpoint) - 1))

                z.sub_(grad, alpha=lr)

        return loss

ScheduleFreeWrapper

Bases: BaseOptimizer

Schedule-Free Wrapper for any base optimizer.

This version uses a memory-efficient swap operation but may be slower than the reference version. In most cases the performance difference is negligible. For the best possible performance and memory-usage, Schedule-Free needs to be directly integrated with the base optimizer.

When using this version, you can disable the base optimizer's momentum, as it's no longer necessary when using our wrapper's momentum (although you can use both types of momentum if you want).

If you set weight decay on the base optimizer, it computes weight decay at z. We offer the option to compute weight decay at y via this wrapper's weight_decay parameter, which seems to give better results in our experiments. This approach to decay only works correctly if the base optimizer uses group['lr'] as the current learning rate.

Parameters:

Name Type Description Default
optimizer Optimizer

base optimizer instance or class to wrap.

required
momentum float

momentum factor.

0.9
weight_decay float

weight decay (L2 penalty).

0.0
r float

use polynomial weighting in the average with power r.

0.0
weight_lr_power float

during warmup, weights in the average are equal to lr raised to this power; 0 disables weighting.

2.0
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
class ScheduleFreeWrapper(BaseOptimizer):
    r"""Schedule-Free Wrapper for any base optimizer.

    This version uses a memory-efficient swap operation but may be slower than the reference version. In most cases
    the performance difference is negligible. For the best possible performance and memory-usage, Schedule-Free
    needs to be directly integrated with the base optimizer.

    When using this version, you can disable the base optimizer's momentum, as it's no longer necessary when using
    our wrapper's momentum (although you can use both types of momentum if you want).

    If you set weight decay on the base optimizer, it computes weight decay at $z$. We offer the option to compute
    weight decay at $y$ via this wrapper's `weight_decay` parameter, which seems to give better results in our
    experiments. This approach to decay only works correctly if the base optimizer uses group['lr'] as the current
    learning rate.

    Args:
        optimizer (Optimizer): base optimizer instance or class to wrap.
        momentum (float): momentum factor, must be strictly between 0 and 1 (``step`` and ``eval`` divide by it).
        weight_decay (float): weight decay (L2 penalty).
        r (float): use polynomial weighting in the average with power r.
        weight_lr_power (float): during warmup, weights in the average are equal to lr raised to this power;
            0 disables weighting.
        maximize (bool): maximize the objective instead of minimizing.
        **kwargs: forwarded to ``load_optimizer`` when building the base optimizer.

    """

    def __init__(
        self,
        optimizer: OptimizerInstanceOrClass,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        maximize: bool = False,
        **kwargs,
    ):
        # momentum == 0 would divide by zero in `step` / `eval`, so the lower bound is exclusive too.
        self.validate_range(momentum, 'momentum', 0.0, 1.0, '()')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.momentum = momentum
        self.weight_decay = weight_decay
        self.r = r
        self.weight_lr_power = weight_lr_power
        self.train_mode: bool = False
        self.maximize = maximize

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        self.state: State = defaultdict(dict)
        self.defaults: Defaults = self.optimizer.defaults

    def __str__(self) -> str:
        return 'ScheduleFree'

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def __getstate__(self):
        return {'state': self.state, 'optimizer': self.optimizer}

    def add_param_group(self, param_group):
        return self.optimizer.add_param_group(param_group)

    def state_dict(self) -> State:
        return {'schedulefree_state': self.state, 'base_optimizer': self.optimizer.state_dict()}

    def load_state_dict(self, state: State) -> None:
        r"""Load state."""
        self.state = state['schedulefree_state']
        self.optimizer.load_state_dict(state['base_optimizer'])

    def zero_grad(self, set_to_none: bool = True) -> None:
        self.optimizer.zero_grad(set_to_none)

    @torch.no_grad()
    def eval(self):
        r"""Move parameters to the evaluation point by extrapolating away from ``z`` (no-op if already in eval)."""
        if not self.train_mode:
            return

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    p.lerp_(end=state['z'], weight=1.0 - 1.0 / self.momentum)

        self.train_mode = False

    @torch.no_grad()
    def train(self):
        r"""Move parameters back to the training point, interpolating toward ``z`` (no-op if already in train)."""
        if self.train_mode:
            return

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    p.lerp_(end=state['z'], weight=1.0 - self.momentum)

        self.train_mode = True

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Initialize the step counter and the per-parameter ``z`` buffer.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if 'z' not in state:
                state['z'] = p.clone()

    @staticmethod
    def swap(x: torch.Tensor, y: torch.Tensor) -> None:
        r"""Swap the contents of two same-sized tensors in place via XOR on their raw bytes (no extra buffer)."""
        x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8))
        y.view(torch.uint8).bitwise_xor_(x.view(torch.uint8))
        x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8))

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure`` or None.

        Raises:
            ValueError: if the wrapper is not in train mode.
        """
        if not self.train_mode:
            raise ValueError('optimizer was not in train mode when step is called. call .train() before training')

        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z = state['z']

                self.apply_weight_decay(
                    z,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                    ratio=1.0 - self.momentum,
                )

                # un-extrapolate: y -> x
                p.lerp_(end=z, weight=1.0 - 1.0 / self.momentum)

                # afterwards p holds z (stepped by the base optimizer) and state['z'] temporarily holds x
                self.swap(z, p)

        self.optimizer.step()

        for group in self.param_groups:
            lr: float = group['lr'] * group.get('d', 1.0)
            lr_max = group['lr_max'] = max(lr, group.get('lr_max', 0))

            # polynomial weighting uses the `r` power (not the learning rate)
            weight: float = (group['step'] ** self.r) * (lr_max ** self.weight_lr_power)  # fmt: skip
            weight_sum = group['weight_sum'] = group.get('weight_sum', 0.0) + weight

            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                z = state['z']

                # swap back: p holds x again, state['z'] holds the updated z
                self.swap(z, p)

                # update the average x toward z
                p.lerp_(end=z, weight=checkpoint)

                # move p to the training point y
                p.lerp_(end=state['z'], weight=1.0 - self.momentum)

        return loss

load_state_dict(state)

Load state.

Source code in pytorch_optimizer/optimizer/schedulefree.py
580
581
582
583
def load_state_dict(self, state: State) -> None:
    r"""Restore the schedule-free buffers, then hand the base optimizer its own saved state."""
    self.state = state['schedulefree_state']
    base_optimizer_state = state['base_optimizer']
    self.optimizer.load_state_dict(base_optimizer_state)

SCION

Bases: BaseOptimizer

Training Deep Learning Models with Norm-Constrained LMOs.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
momentum float

momentum factor, i.e. 1.0 minus the usual momentum coefficient (the weight given to the new gradient).

0.1
constraint bool

whether to use a constraint SCG or not.

False
norm_type int

supported LMO norm types, from 0 to 7. 0 stands for no normalization and 1 stands for AUTO. Please check the LMONorm Enum class for the details.

AUTO
norm_kwargs Optional[Dict]

arguments for the Norm.

None
scale float

scale factor. For Transformer block typical value is 50.0, and 3000.0 for others (e.g., Embeddings, LM head).

1.0
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Example

>>> radius = 50.0
>>> parameter_groups = [{
...     'params': model.transformer.h.parameters(),
...     'norm_type': 'spectral',
...     'norm_kwargs': {},
...     'scale': radius,
... }, {
...     'params': model.lm_head.parameters(),
...     'norm_type': 'sign',
...     'norm_kwargs': {},
...     'scale': radius * 60.0,
... }]
>>> optimizer = SCION(parameter_groups)

For more details, check out https://github.com/LIONS-EPFL/scion/tree/main?tab=readme-ov-file#examples

Source code in pytorch_optimizer/optimizer/scion.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
class SCION(BaseOptimizer):
    """Training Deep Learning Models with Norm-Constrained LMOs.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): weight of the new gradient in the running direction ``d``, i.e. 1.0 minus the usual
            momentum coefficient.
        constraint (bool): whether to use the constrained SCG variant (shrinks parameters by ``1 - lr`` each step
            instead of applying weight decay).
        norm_type (int): supported LMO norm types, from 0 to 7. 0 stands for no normalization and 1 stands for
            AUTO. Please check the LMONorm Enum class for the details.
        norm_kwargs (Optional[Dict]): arguments for the Norm.
        scale (float): scale factor. For Transformer blocks a typical value is 50.0, and 3000.0 for others
            (e.g., Embeddings, LM head).
        weight_decay (float): weight decay (L2 penalty); only used when ``constraint`` is False.
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    Example:
        >>> radius = 50.0
        >>> parameter_groups = [{
        ...     'params': model.transformer.h.parameters(),
        ...     'norm_type': 'spectral',
        ...     'norm_kwargs': {},
        ...     'scale': radius,
        ... }, {
        ...     'params': model.lm_head.parameters(),
        ...     'norm_type': 'sign',
        ...     'norm_kwargs': {},
        ...     'scale': radius * 60.0,
        ... }]
        >>> optimizer = SCION(parameter_groups)

        For more details, check out https://github.com/LIONS-EPFL/scion/tree/main?tab=readme-ov-file#examples

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.1,
        constraint: bool = False,
        norm_type: int = LMONorm.AUTO,
        norm_kwargs: Optional[Dict] = None,
        scale: float = 1.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, '(]')
        self.validate_positive(scale, 'scale')

        self.foreach = foreach
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'constraint': constraint,
            'norm_type': norm_type,
            'norm_kwargs': {} if norm_kwargs is None else norm_kwargs,
            'scale': scale,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'foreach': foreach,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SCION'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        r"""Create the step counter and a zero-filled ``d`` buffer for every parameter that has a gradient.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        group.setdefault('step', 0)

        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if 'd' not in param_state:
                param_state['d'] = torch.zeros_like(gradient)

    @torch.no_grad()
    def init(self):
        r"""Re-initialize every parameter with its group's LMO norm, then multiply it by the group ``scale``."""
        for group in self.param_groups:
            lmo_norm = build_lmo_norm(group['norm_type'], **group['norm_kwargs'])
            for param in group['params']:
                lmo_norm.init(param)
                param.mul_(group['scale'])

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        r"""Return False when foreach is explicitly disabled, otherwise defer to ``can_use_foreach``."""
        requested = group.get('foreach')
        return requested is not False and self.can_use_foreach(group, requested)

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        norm: Norm,
        ds: List[torch.Tensor],
    ) -> None:
        r"""Multi-tensor update of one parameter group; mirrors ``_step_per_param``."""
        if self.maximize:
            torch._foreach_neg_(grads)

        # d <- (1 - momentum) * d + momentum * grad
        torch._foreach_lerp_(ds, grads, group['momentum'])

        updates = [norm.lmo(d) for d in ds]
        torch._foreach_mul_(updates, group['scale'])

        if group['constraint']:
            torch._foreach_mul_(params, 1.0 - group['lr'])
        elif group['weight_decay'] > 0.0:
            self.apply_weight_decay_foreach(
                params,
                grads=grads,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=False,
            )

        torch._foreach_add_(params, updates, alpha=-group['lr'])

    def _step_per_param(self, group: ParamGroup, norm: Norm) -> None:
        r"""Update each parameter of ``group`` individually."""
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            self.maximize_gradient(gradient, maximize=self.maximize)

            direction = self.state[param]['d']
            direction.mul_(1.0 - group['momentum']).add_(gradient, alpha=group['momentum'])

            update = norm.lmo(direction).mul_(group['scale'])

            if group['constraint']:
                param.mul_(1.0 - group['lr'])
            elif group['weight_decay'] > 0.0:
                self.apply_weight_decay(
                    param,
                    grad=gradient,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

            param.add_(update, alpha=-group['lr'])

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure`` or None.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            norm = build_lmo_norm(group['norm_type'], **group['norm_kwargs'])

            if not self._can_use_foreach(group):
                self._step_per_param(group, norm)
                continue

            params, grads, state_dict = self.collect_trainable_params(group, self.state, state_keys=['d'])
            if params:
                self._step_foreach(group, params, grads, norm, state_dict['d'])

        return loss

SCIONLight

Bases: BaseOptimizer

Memory-efficient variant of the Scion optimizer.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
momentum float

momentum factor, i.e. 1.0 minus the usual momentum coefficient.

0.1
constraint bool

whether to use a constraint SCG or not.

False
norm_type int

supported LMO norm types, from 0 to 7. 0 stands for no normalization and 1 stands for AUTO. Please check the LMONorm Enum class for the details.

AUTO
norm_kwargs Optional[Dict]

arguments for the Norm.

None
scale float

scale factor. For Transformer block typical value is 50.0, and 3000.0 for others (e.g., Embeddings, LM head).

1.0
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

the optimizer uses decoupled weight decay as in AdamW.

True
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

maximize the objective with respect to the params, instead of minimizing.

False
Example

>>> radius = 50.0
>>> parameter_groups = [{
...     'params': model.transformer.h.parameters(),
...     'norm_type': 'spectral',
...     'norm_kwargs': {},
...     'scale': radius,
... }, {
...     'params': model.lm_head.parameters(),
...     'norm_type': 'sign',
...     'norm_kwargs': {},
...     'scale': radius * 60.0,
... }]
>>> optimizer = SCIONLight(parameter_groups)

For more details, check out https://github.com/LIONS-EPFL/scion/tree/main?tab=readme-ov-file#examples

Source code in pytorch_optimizer/optimizer/scion.py
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
class SCIONLight(BaseOptimizer):
    r"""Memory-efficient variant of the Scion optimizer.

    This variant keeps no separate momentum buffer: ``p.grad`` itself is reused as the buffer. After each update
    the gradient is scaled by ``1 - momentum`` instead of being cleared, so the next backward pass accumulates the
    new gradient on top of the decayed history.
    NOTE(review): callers presumably must not zero gradients between steps for momentum to take effect — confirm
    against the upstream Scion examples.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): momentum factor in (0, 1]. The gradient buffer is multiplied by ``1.0 - momentum`` after
            each step, so ``1.0 - momentum`` plays the role of the usual momentum coefficient (e.g. 0.1 <-> 0.9).
        constraint (bool): whether to use a constraint SCG or not. When True, parameters are shrunk by
            ``1 - lr`` each step and ``weight_decay`` is not applied.
        norm_type (int): supported LMO norm types. 0 stands for no normalization and 1 stands for AUTO. 0 to 7.
            Please check LMONorm Enum class for the details.
        norm_kwargs (Optional[Dict]): arguments for the Norm.
        scale (float): scale factor. For Transformer block typical value is 50.0, and 3000.0 for others
            (e.g., Embeddings, LM head).
        weight_decay (float): weight decay (L2 penalty). Only used when ``constraint`` is False.
        weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): maximize the objective with respect to the params, instead of minimizing.

    Example:
        >>> radius = 50.0
        >>> parameter_groups = [{
        ...     'params': model.transformer.h.parameters(),
        ...     'norm_type': 'spectral',
        ...     'norm_kwargs': {},
        ...     'scale': radius,
        ... }, {
        ...     'params': model.lm_head.parameters(),
        ...     'norm_type': 'sign',
        ...     'norm_kwargs': {},
        ...     'scale': radius * 60.0,
        ... }]
        >>> optimizer = SCIONLight(parameter_groups)

        For more details, check out https://github.com/LIONS-EPFL/scion/tree/main?tab=readme-ov-file#examples

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.1,
        constraint: bool = False,
        norm_type: int = LMONorm.AUTO,
        norm_kwargs: Optional[Dict] = None,
        scale: float = 1.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, '(]')
        self.validate_positive(scale, 'scale')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.foreach = foreach
        self.maximize = maximize

        if norm_kwargs is None:
            norm_kwargs = {}

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'constraint': constraint,
            'norm_type': norm_type,
            'norm_kwargs': norm_kwargs,
            'scale': scale,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'foreach': foreach,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SCIONLight'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        # No per-parameter state: the gradient tensor itself serves as the momentum buffer.
        if 'step' not in group:
            group['step'] = 0

    @torch.no_grad()
    def init(self):
        """Initialize every parameter with its group's norm and rescale it by the group's ``scale``."""
        for group in self.param_groups:
            norm = build_lmo_norm(group['norm_type'], **group['norm_kwargs'])
            for p in group['params']:
                norm.init(p)
                p.mul_(group['scale'])

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        if group.get('foreach') is False:
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        norm: Norm,
    ) -> None:
        momentum = group['momentum']

        # Flip into the descent direction for maximization; undone at the end because `grads` persist as the
        # momentum buffer across steps.
        if self.maximize:
            torch._foreach_neg_(grads)

        # NOTE(review): assumes norm.lmo returns a new tensor; if it returned its input, the in-place scale below
        # would also rescale the gradient buffer.
        updates = [norm.lmo(grad) for grad in grads]
        torch._foreach_mul_(updates, group['scale'])

        if group['constraint']:
            torch._foreach_mul_(params, 1.0 - group['lr'])

        if not group['constraint'] and group['weight_decay'] > 0.0:
            self.apply_weight_decay_foreach(
                params,
                grads=grads,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=False,
            )

        torch._foreach_add_(params, updates, alpha=-group['lr'])

        # Decay the buffer instead of clearing it; the next backward pass accumulates on top of it.
        if momentum != 1.0:
            torch._foreach_mul_(grads, 1.0 - momentum)

        # Restore the raw-gradient sign so the next backward accumulates a gradient of the same sign; otherwise
        # the history would flip sign on every step under maximization.
        if self.maximize:
            torch._foreach_neg_(grads)

    def _step_per_param(self, group: ParamGroup, norm: Norm) -> None:
        momentum = group['momentum']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            self.maximize_gradient(grad, maximize=self.maximize)

            update = norm.lmo(grad).mul_(group['scale'])

            if group['constraint']:
                p.mul_(1.0 - group['lr'])

            if not group['constraint'] and group['weight_decay'] > 0.0:
                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

            p.add_(update, alpha=-group['lr'])

            # Decay the buffer instead of clearing it; the next backward pass accumulates on top of it.
            if momentum != 1.0:
                grad.mul_(1.0 - momentum)

            # Undo the in-place maximize negation so the stored buffer keeps the raw-gradient sign.
            if self.maximize:
                grad.neg_()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            norm = build_lmo_norm(group['norm_type'], **group['norm_kwargs'])

            if self._can_use_foreach(group):
                params, grads, _ = self.collect_trainable_params(group, self.state)
                if params:
                    self._step_foreach(group, params, grads, norm)
            else:
                self._step_per_param(group, norm)

        return loss

SGDP

Bases: BaseOptimizer

SGD + Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
momentum float

Momentum factor.

0.0
dampening float

Dampening for momentum.

0.0
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
delta float

Threshold that determines whether a set of parameters is scale-invariant or not.

0.1
wd_ratio float

Relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters.

0.1
nesterov bool

Enables Nesterov momentum.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamp.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
class SGDP(BaseOptimizer):
    """SGD + Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

    For parameters with more than one dimension, the momentum direction is passed through ``projection`` and the
    returned ratio rescales that parameter's weight decay. Presumably ``projection`` removes the radial component
    for scale-invariant weights, as in the AdamP/SGDP paper.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float): Momentum factor, must be in [0, 1).
        dampening (float): Dampening for momentum.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        delta (float): Threshold that determines whether a set of parameters is scale-invariant or not.
        wd_ratio (float): Relative weight decay applied on scale-invariant parameters compared to that applied
            on scale-variant parameters.
        nesterov (bool): Enables Nesterov momentum.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    Raises:
        ValueError: if ``momentum`` is outside [0, 1), or another hyper-parameter fails validation.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        delta: float = 0.1,
        wd_ratio: float = 0.1,
        nesterov: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        # The weight-decay ratio below divides by (1 - momentum), so momentum must stay strictly below 1.
        self.validate_range(momentum, 'momentum', 0.0, 1.0, '[)')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(wd_ratio, 'wd_ratio', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'momentum': momentum,
            'dampening': dampening,
            'delta': delta,
            'wd_ratio': wd_ratio,
            'nesterov': nesterov,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDP'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]
            if len(state) == 0:
                state['momentum'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                buf = state['momentum']

                p, grad, buf = self.view_as_real(p, grad, buf)

                buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

                # Work on a copy so the Nesterov/projection adjustments never leak into the stored buffer.
                d_p = buf.clone()
                if group['nesterov']:
                    d_p = d_p.mul_(momentum).add_(grad)

                wd_ratio: float = 1.0
                if len(p.shape) > 1:
                    d_p, wd_ratio = projection(
                        p,
                        grad,
                        d_p,
                        group['delta'],
                        group['wd_ratio'],
                        group['eps'],
                    )

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    # Divide by (1 - momentum) as in the reference SGDP; safe because momentum < 1 is validated.
                    ratio=wd_ratio / (1.0 - momentum),
                )

                p.add_(d_p, alpha=-group['lr'])

        return loss

SGDSaI

Bases: BaseOptimizer

No More Adam: Learning Rate Scaling at Initialization is All You Need.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.01
momentum float

coefficient used for computing the running average of the gradient.

0.9
weight_decay float

weight decay (L2 penalty).

0.01
weight_decouple bool

optimizer uses decoupled weight decay as in AdamW.

True
eps float

term added to denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
class SGDSaI(BaseOptimizer):
    """No More Adam: Learning Rate Scaling at Initialization is All You Need.

    On the first call to :meth:`step`, a one-off :meth:`warmup_step` runs first. It records a per-parameter
    gradient signal-to-noise ratio (``gsnr``, the gradient norm divided by its standard deviation). Every later
    update scales the learning rate of each parameter by that stored value.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): coefficient used for computing the running average of the gradient.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): optimizer uses decoupled weight decay as in AdamW.
        eps (float): term added to denominator to improve numerical stability.
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.has_warmup: bool = False
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDSaI'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            # Never reset an existing buffer, e.g. if warmup_step is invoked more than once.
            if group['momentum'] > 0.0 and 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def warmup_step(self, closure: Closure = None) -> Loss:
        """Compute and store the per-parameter gradient SNR (``state['gsnr']``) without updating any parameter.

        NOTE(review): only parameters with a gradient at warmup get a ``gsnr``; ``step`` assumes it exists.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Warmup does not move any parameter, so it does not count as an optimization step.
            self.init_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # Norm and std are sign-invariant, so maximize needs no handling here. Negating p.grad in place
                # would be undone by step()'s own negation and silently turn the first ascent into a descent.
                sigma = grad.std().nan_to_num_() if grad.ndim > 1 and grad.size(0) != 1 else 0
                grad_norm = grad.norm()

                g_snr = grad_norm.div_(sigma.add_(group['eps'])) if sigma != 0.0 else grad_norm

                self.state[p]['gsnr'] = g_snr

        self.has_warmup = True

        return loss

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        if not self.has_warmup:
            self.warmup_step(closure)

        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            group['step'] += 1

            momentum: float = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if momentum > 0.0:
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
                else:
                    buf = grad

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                # Per-parameter learning rate = lr * gsnr recorded at warmup.
                p.add_(buf, alpha=-group['lr'] * state['gsnr'])

        return loss

SGDW

Bases: BaseOptimizer

Decoupled Weight Decay Regularization.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.0001
momentum float

momentum factor.

0.0
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

optimizer uses decoupled weight decay as in AdamW.

True
dampening float

dampening for momentum.

0.0
nesterov bool

enables Nesterov momentum.

False
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
class SGDW(BaseOptimizer):
    """Decoupled Weight Decay Regularization.

    The foreach and per-parameter paths share one update rule:
    ``buf = momentum * buf + (1 - dampening) * grad``. The step direction is ``grad + momentum * buf`` with
    Nesterov and ``buf`` without it; with ``momentum == 0`` the raw gradient is used.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): momentum factor.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): optimizer uses decoupled weight decay as in AdamW.
        dampening (float): dampening for momentum.
        nesterov (bool): enables Nesterov momentum.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-4,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        dampening: float = 0.0,
        nesterov: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize
        self.foreach = foreach

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'dampening': dampening,
            'nesterov': nesterov,
            'foreach': foreach,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDW'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            # The buffer must start at zero; seeding it with the parameter values would inject the weights into
            # the first update. It is created even when momentum == 0 so the foreach path can always collect it.
            if 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        if group.get('foreach') is False:
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        momentum_buffers: List[torch.Tensor],
    ) -> None:
        lr = group['lr']
        momentum = group['momentum']
        dampening = group['dampening']

        if self.maximize:
            torch._foreach_neg_(grads)

        # Decay before the momentum update so a coupled (L2) penalty flows into the buffer via the gradient.
        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=lr,
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=False,
        )

        if momentum > 0.0:
            # buf = momentum * buf + (1 - dampening) * grad
            torch._foreach_mul_(momentum_buffers, momentum)
            torch._foreach_add_(momentum_buffers, grads, alpha=1.0 - dampening)

            if group['nesterov']:
                torch._foreach_add_(grads, momentum_buffers, alpha=momentum)
                updates = grads
            else:
                updates = momentum_buffers
        else:
            updates = grads

        torch._foreach_add_(params, updates, alpha=-lr)

    def _step_per_param(self, group: ParamGroup) -> None:
        momentum = group['momentum']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            # Same order as the foreach path: a coupled (L2) penalty is folded into the gradient before it enters
            # the momentum buffer, rather than being added to the persistent buffer itself.
            self.apply_weight_decay(
                p,
                grad=grad,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=False,
            )

            update = grad
            if momentum > 0.0:
                buf = self.state[p]['momentum_buffer']
                buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

                update = grad.add_(buf, alpha=momentum) if group['nesterov'] else buf

            p.add_(update, alpha=-group['lr'])

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['momentum_buffer']
                )
                if params:
                    self._step_foreach(group, params, grads, state_dict['momentum_buffer'])
            else:
                self._step_per_param(group)

        return loss

Shampoo

Bases: BaseOptimizer

Preconditioned Stochastic Tensor Optimization.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
momentum float

momentum factor.

0.0
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

fix weight decay.

False
preconditioning_compute_steps int

how often to compute the preconditioner, tuning memory and compute requirements.

1
matrix_eps float

term added to denominator to improve numerical stability.

1e-06
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/shampoo.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class Shampoo(BaseOptimizer):
    """Preconditioned Stochastic Tensor Optimization.

    One preconditioner matrix is kept per gradient dimension (``pre_cond_{i}``, size ``dim x dim``). On each step
    the gradient is unfolded along every dimension in turn, its Gram matrix is accumulated into that dimension's
    preconditioner, and the gradient is multiplied by the cached ``inv_pre_cond_{i}``. The cache is recomputed
    from ``pre_cond_{i}`` via ``compute_power_svd`` every ``preconditioning_compute_steps`` steps.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): momentum factor.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): fix weight decay.
        preconditioning_compute_steps (int): how often to compute the preconditioner,
            tuning memory and compute requirements.
        matrix_eps (float): scale of the identity each preconditioner is initialized with, for numerical
            stability.
        maximize (bool): maximize the objective instead of minimizing.

    Raises:
        NoSparseGradientError: if a parameter has a sparse gradient.
        NoComplexParameterError: if a parameter is complex.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        preconditioning_compute_steps: int = 1,
        matrix_eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_step(preconditioning_compute_steps, 'preconditioning_compute_steps')
        self.validate_non_negative(matrix_eps, 'matrix_eps')

        self.preconditioning_compute_steps = preconditioning_compute_steps
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'matrix_eps': matrix_eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Shampoo'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if group['momentum'] > 0.0:
                    state['momentum_buffer'] = grad.clone()

                # One (dim x dim) preconditioner per tensor dimension, starting at matrix_eps * I. The inverse
                # cache starts at zero and is only filled on a step divisible by preconditioning_compute_steps,
                # so NOTE(review): with preconditioning_compute_steps > 1 the early updates are zero.
                for dim_id, dim in enumerate(grad.size()):
                    state[f'pre_cond_{dim_id}'] = group['matrix_eps'] * torch.eye(dim, out=grad.new(dim, dim))
                    state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # Blend the raw gradient with the previous *preconditioned* update stored below.
                if momentum > 0.0:
                    grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                order: int = grad.ndimension()
                original_size: torch.Size = grad.size()
                for dim_id, dim in enumerate(grad.size()):
                    pre_cond, inv_pre_cond = state[f'pre_cond_{dim_id}'], state[f'inv_pre_cond_{dim_id}']

                    # Bring dimension `dim_id` to the front and flatten the rest: grad becomes (dim, -1).
                    grad = grad.transpose_(0, dim_id).contiguous()
                    transposed_size = grad.size()

                    grad = grad.view(dim, -1)
                    grad_t = grad.t()

                    # Accumulate this mode's Gram matrix; refresh the cached inverse root periodically.
                    # Presumably compute_power_svd returns pre_cond ** (-1 / order) — confirm in its definition.
                    pre_cond.add_(grad @ grad_t)
                    if group['step'] % self.preconditioning_compute_steps == 0:
                        inv_pre_cond.copy_(compute_power_svd(pre_cond, order))

                    # On the last dimension, right-multiplying by the preconditioner restores the original layout;
                    # otherwise left-multiply and keep the transposed shape for the next iteration.
                    if dim_id == order - 1:
                        grad = grad_t @ inv_pre_cond
                        grad = grad.view(original_size)
                    else:
                        grad = inv_pre_cond @ grad
                        grad = grad.view(transposed_size)

                # Stored unconditionally (even with momentum == 0); read back next step only when momentum > 0.
                state['momentum_buffer'] = grad

                p.add_(grad, alpha=-group['lr'])

        return loss

SignSGD

Bases: BaseOptimizer

Compressed Optimisation for Non-Convex Problems.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.001
momentum float

momentum factor (0.0 = SignSGD, >0 = Signum).

0.9
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

optimizer uses decoupled weight decay as in AdamW.

True
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
class SignSGD(BaseOptimizer):
    """Compressed Optimisation for Non-Convex Problems.

    With ``momentum == 0`` every parameter is moved by ``-lr * sign(grad)`` (signSGD). With ``momentum > 0`` the sign
    of an exponential moving average of the gradient is used instead (Signum).

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        momentum (float): momentum factor (0.0 = SignSGD, >0 = Signum).
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): optimizer uses decoupled weight decay as in AdamW.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): maximize the objective instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize
        self.foreach = foreach

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'foreach': foreach,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SignSGD'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialise the group step counter and allocate momentum buffers on first use.

        Raises:
            NoSparseGradientError: if any gradient in the group is sparse.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            # ``init_group`` runs on every ``step`` call: allocate the buffer only once, otherwise the accumulated
            # momentum would be reset to zero each step and Signum would silently degrade to plain signSGD.
            if group['momentum'] > 0.0 and 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

    def _apply_decay(self, group: ParamGroup, p: torch.Tensor, grad: torch.Tensor) -> None:
        """Apply the group's weight decay (decoupled or L2, per ``weight_decouple``) to one parameter."""
        if group['weight_decay'] > 0.0:
            self.apply_weight_decay(
                p=p,
                grad=grad,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=False,
            )

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        # the multi-tensor path only implements the momentum (Signum) variant
        if group.get('foreach') is False or group['momentum'] == 0.0:
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        momentum_buffers: List[torch.Tensor],
    ) -> None:
        lr = group['lr']

        if self.maximize:
            torch._foreach_neg_(grads)

        for p, grad in zip(params, grads):
            self._apply_decay(group, p, grad)

        # buf <- momentum * buf + (1 - momentum) * grad
        torch._foreach_lerp_(momentum_buffers, grads, weight=1.0 - group['momentum'])

        updates = [buf.sign() for buf in momentum_buffers]
        torch._foreach_add_(params, updates, alpha=-lr)

    def _step_per_param(self, group: ParamGroup) -> None:
        momentum = group['momentum']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            self.maximize_gradient(grad, maximize=self.maximize)

            self._apply_decay(group, p, grad)

            state = self.state[p]

            if momentum > 0.0:
                buf = state['momentum_buffer']
                buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
            else:
                buf = grad

            # complex tensors need ``sgn`` (z / |z|); ``sign`` is only defined for real tensors
            p.add_(torch.sign(buf) if not torch.is_complex(buf) else torch.sgn(buf), alpha=-group['lr'])

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['momentum_buffer']
                )
                if params:
                    self._step_foreach(group, params, grads, state_dict['momentum_buffer'])
            else:
                self._step_per_param(group)

        return loss

SimplifiedAdEMAMix

Bases: BaseOptimizer

Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.0001
betas Betas

Coefficients used for computing running averages of the gradient and of its square.

(0.99, 0.95)
alpha float

Coefficient for mixing the current gradient and EMA.

0.0
beta1_warmup Optional[int]

Number of warmup steps used to increase beta1.

None
min_beta1 float

Minimum value of beta1 to start from.

0.9
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether to use decoupled weight decay as in AdamW.

True
fixed_decay bool

Apply fixed weight decay instead of adaptive.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ademamix.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
class SimplifiedAdEMAMix(BaseOptimizer):
    """Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants.

    The update is ``(alpha * grad + exp_avg) / (sqrt(exp_avg_sq / den_sum) + eps)``, where ``den_sum`` is the
    bias-correction term of the second moment.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of its square.
        alpha (float): Coefficient for mixing the current gradient and EMA.
        beta1_warmup (Optional[int]): Number of warmup steps used to increase beta1.
        min_beta1 (float): Minimum value of beta1 to start from.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether to use decoupled weight decay as in AdamW.
        fixed_decay (bool): Apply fixed weight decay instead of adaptive.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-4,
        betas: Betas = (0.99, 0.95),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        alpha: float = 0.0,
        beta1_warmup: Optional[int] = None,
        min_beta1: float = 0.9,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(alpha, 'alpha')
        self.validate_non_negative(min_beta1, 'min_beta1')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'beta1_warmup': beta1_warmup,
            'min_beta1': min_beta1,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SimplifiedAdEMAMix'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialise the group step counter and per-parameter moment state on first use.

        Raises:
            NoSparseGradientError: if a gradient is sparse.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['num_sum'] = 0.0
                state['den_sum'] = 0.0

    @staticmethod
    def linear_hl_warmup_scheduler(step: int, beta_end: float, beta_start: float = 0.0, warmup: int = 1) -> float:
        """Warm beta up from ``beta_start`` to ``beta_end`` by linearly interpolating the EMA half-life.

        ``f`` maps a beta to its half-life (in steps, minus one) and ``f_inv`` maps it back.
        """

        def f(beta: float, eps: float = 1e-8) -> float:
            return math.log(0.5) / math.log(beta + eps) - 1.0

        def f_inv(t: float) -> float:
            return math.pow(0.5, 1.0 / (t + 1))

        if step < warmup:
            a: float = step / float(warmup)
            return f_inv((1.0 - a) * f(beta_start) + a * f(beta_end))

        return beta_end

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            if group['beta1_warmup']:
                beta1 = self.linear_hl_warmup_scheduler(
                    group['step'], beta_end=beta1, beta_start=group['min_beta1'], warmup=group['beta1_warmup']
                )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # NOTE(review): ``num_sum`` (beta1 * num_sum + 1) is tracked but never read; it matches an
                # un-normalised momentum ``m = beta1 * m + g``, whereas exp_avg here uses ``1 - beta1`` -- confirm.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                state['num_sum'] = beta1 * state['num_sum'] + 1.0
                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

                # den_sum = 1 - beta2^t is the second-moment bias correction, so
                #   (alpha * g + m) / (sqrt(v / den_sum) + eps)
                #     = (alpha * g + m) * sqrt(den_sum) / (sqrt(v) + sqrt(den_sum) * eps)
                # i.e. the factor must be multiplied back in, not divided out.
                sqrt_den_sum: float = math.sqrt(state['den_sum'])

                de_nom = exp_avg_sq.sqrt().add_(sqrt_den_sum * group['eps'])

                update = (group['alpha'] * grad + exp_avg).div_(de_nom).mul_(sqrt_den_sum)

                p.add_(update, alpha=-group['lr'])

        return loss

SM3

Bases: BaseOptimizer

Memory-Efficient Adaptive Optimization.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.1
momentum float

Coefficient used to scale prior updates before adding. This drastically increases memory usage if momentum > 0.0. This is ignored if the parameter's gradient is sparse.

0.0
beta float

Coefficient used for exponential moving averages.

0.0
eps float

Term added to the denominator to improve numerical stability.

1e-30
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sm3.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class SM3(BaseOptimizer):
    r"""Memory-Efficient Adaptive Optimization.

    Dense parameters of rank ``r`` keep ``r`` small accumulators (one per axis, broadcastable against the parameter)
    instead of a full second-moment tensor; sparse gradients use a single per-row accumulator.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float): Coefficient used to scale prior updates before adding. This drastically increases
            memory usage if momentum > 0.0. This is ignored if the parameter's gradient is sparse.
        beta (float): Coefficient used for exponential moving averages.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-1,
        momentum: float = 0.0,
        beta: float = 0.0,
        eps: float = 1e-30,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {'lr': lr, 'momentum': momentum, 'beta': beta, 'eps': eps}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SM3'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialise the group step counter and allocate accumulators (and, if used, momentum) on first use.

        Raises:
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            grad = p.grad

            shape = grad.shape
            rank: int = len(shape)

            state = self.state[p]

            if len(state) == 0:
                if grad.is_sparse:
                    # sparse gradients share a single accumulator indexed by the first (row) dimension
                    state['accumulator_0'] = torch.zeros(shape[0], dtype=grad.dtype, device=grad.device)
                elif rank == 0:
                    state['accumulator_0'] = torch.zeros_like(grad)
                else:
                    # accumulator i has shape [1, ..., shape[i], ..., 1] so it broadcasts against the parameter
                    for i in range(rank):
                        state[f'accumulator_{i}'] = torch.zeros(
                            [1] * i + [shape[i]] + [1] * (rank - 1 - i), dtype=grad.dtype, device=grad.device
                        )

            # The momentum buffer is a full-size copy of the parameter; only pay for it when it is actually used
            # (dense gradient and momentum > 0), keeping the optimizer memory-efficient otherwise.
            if group['momentum'] > 0.0 and not grad.is_sparse and 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(grad)

    @staticmethod
    def make_sparse(grad: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Build a sparse tensor with ``grad``'s indices and size but the given ``values``."""
        if grad._indices().dim() == 0 or values.dim() == 0:
            return grad.new().resize_as_(grad)
        return grad.new(grad._indices(), values, grad.size())

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            momentum, beta = group['momentum'], group['beta']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                shape = grad.shape
                rank: int = len(shape)

                state = self.state[p]

                if grad.is_sparse:
                    grad = grad.coalesce()

                    acc = state['accumulator_0']
                    update_values = torch.gather(acc, 0, grad._indices()[0])
                    if beta > 0.0:
                        update_values.mul_(beta)
                    update_values.addcmul_(grad._values(), grad._values(), value=1.0 - beta)

                    nu_max = reduce_max_except_dim(self.make_sparse(grad, update_values).to_dense(), 0).squeeze_()

                    # Rows absent from the sparse gradient densify to 0 in ``nu_max``; an element-wise max keeps
                    # their accumulated value instead of wiping it. On touched rows with beta == 0 the max equals
                    # ``nu_max`` anyway, because ``acc + grad^2 >= acc``.
                    torch.max(acc, nu_max, out=acc)

                    update_values.add_(group['eps']).rsqrt_().mul_(grad._values())

                    update = self.make_sparse(grad, update_values)
                else:
                    # the per-element second-moment estimate is the minimum over all per-axis accumulators
                    update = state['accumulator_0'].clone()
                    for i in range(1, rank):
                        update = torch.min(update, state[f'accumulator_{i}'])

                    if beta > 0.0:
                        update.mul_(beta)
                    update.addcmul_(grad, grad, value=1.0 - beta)

                    # A scalar (rank-0) parameter still has one accumulator that must be refreshed; it holds the
                    # full statistic, so there is nothing to reduce over.
                    for i in range(max(rank, 1)):
                        acc = state[f'accumulator_{i}']
                        nu_max = reduce_max_except_dim(update, i) if rank > 0 else update
                        if beta > 0.0:
                            torch.max(acc, nu_max, out=acc)
                        else:
                            acc.copy_(nu_max)

                    update.add_(group['eps']).rsqrt_().mul_(grad)

                    if momentum > 0.0:
                        m = state['momentum_buffer']
                        m.mul_(momentum).add_(update, alpha=1.0 - momentum)
                        update = m

                p.add_(update, alpha=-group['lr'])

        return loss

SOAP

Bases: BaseOptimizer

Improving and Stabilizing Shampoo using Adam.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.003
betas Betas

Coefficients used for computing running averages of the gradient and of its square.

(0.95, 0.95)
shampoo_beta Optional[float]

If not None, use this beta for the moving average of the pre-conditioner statistics (L and R in the paper, stored in state['GG']) instead of betas[1].

None
weight_decay float

Weight decay (L2 penalty).

0.01
precondition_frequency int

How often to update the pre-conditioner.

10
max_precondition_dim int

Maximum dimension of the pre-conditioner; axes larger than this are not preconditioned. The default of 10000 excludes most common vocabulary sizes while still covering typical layer widths.

10000
merge_dims bool

Whether to merge dimensions of the pre-conditioner.

False
precondition_1d bool

Whether to precondition 1D gradients.

False
correct_bias bool

Whether to correct bias in Adam.

True
normalize_gradient bool

Whether to normalize the gradients.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/soap.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
class SOAP(BaseOptimizer):
    """Improving and Stabilizing Shampoo using Adam.

    Adam is run in the eigen-basis of Shampoo's pre-conditioner. Per-axis outer-product statistics (``state['GG']``)
    are accumulated, their eigen-bases (``state['Q']``) are maintained, and the second moment is tracked on
    gradients rotated into those bases.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of its square.
        shampoo_beta (Optional[float]): If not None, use this beta for the pre-conditioner
            (L and R in paper, state['GG'] below) moving average instead of betas[1].
        weight_decay (float): Weight decay, applied decoupled (AdamW-style) after each parameter update.
        precondition_frequency (int): How often to update the pre-conditioner.
        max_precondition_dim (int): Maximum dimension of the pre-conditioner. Axes larger than this are not
            preconditioned. The default of 10000 excludes most common vocab sizes while including layers.
        merge_dims (bool): Whether to merge dimensions of the pre-conditioner.
        precondition_1d (bool): Whether to precondition 1D gradients.
        correct_bias (bool): Whether to correct bias in Adam.
        normalize_gradient (bool): Whether to normalize the gradients.
        data_format (DataFormat): Layout of 4-D parameters. Only consulted when ``merge_dims`` is True: with
            ``'channels_last'``, 4-D tensors are restored through the ``permute(0, 3, 1, 2)`` shape followed by
            ``permute(0, 2, 3, 1)``; any other value is treated as channels-first.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-3,
        betas: Betas = (0.95, 0.95),
        shampoo_beta: Optional[float] = None,
        weight_decay: float = 1e-2,
        precondition_frequency: int = 10,
        max_precondition_dim: int = 10000,
        merge_dims: bool = False,
        precondition_1d: bool = False,
        correct_bias: bool = True,
        normalize_gradient: bool = False,
        data_format: DataFormat = 'channels_first',
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(shampoo_beta, 'shampoo_beta')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(precondition_frequency, 'precondition_frequency')
        self.validate_positive(max_precondition_dim, 'max_precondition_dim')
        self.validate_non_negative(eps, 'eps')

        self.data_format = data_format
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'shampoo_beta': shampoo_beta,
            'weight_decay': weight_decay,
            'precondition_frequency': precondition_frequency,
            'max_precondition_dim': max_precondition_dim,
            'merge_dims': merge_dims,
            'precondition_1d': precondition_1d,
            'correct_bias': correct_bias,
            'normalize_gradient': normalize_gradient,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SOAP'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Create per-parameter state the first time a parameter is seen.

        Allocates the Adam moments, builds the pre-conditioner statistics from the current gradient and computes
        their initial eigen-bases. On the very first ``step`` call ``group['step']`` is still 0 here (``step``
        increments it afterwards), so no QR refresh happens during that initialisation.

        Raises:
            NoSparseGradientError: if a gradient is sparse.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        _, beta2 = group['betas']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(grad)
                state['exp_avg_sq'] = torch.zeros_like(grad)

                # shampoo_beta falls back to beta2 when not given explicitly
                self.init_pre_conditioner(
                    grad,
                    state,
                    precondition_frequency=group['precondition_frequency'],
                    shampoo_beta=group['shampoo_beta'] if group['shampoo_beta'] is not None else beta2,
                    max_precondition_dim=group['max_precondition_dim'],
                    precondition_1d=group['precondition_1d'],
                    merge_dims=group['merge_dims'],
                )

                self.update_pre_conditioner(
                    grad,
                    state,
                    step=group['step'],
                    max_precondition_dim=group['max_precondition_dim'],
                    precondition_1d=group['precondition_1d'],
                    merge_dims=group['merge_dims'],
                )

    def project(
        self,
        grad: torch.Tensor,
        state,
        merge_dims: bool = False,
        max_precondition_dim: int = 10000,
        project_type: str = 'forward',
    ) -> torch.Tensor:
        """Rotate ``grad`` into (``'forward'``) or back out of (any other value) the eigen-bases in ``state['Q']``.

        Each axis is contracted in turn with its basis (basis dim 0 for forward, dim 1 for backward). ``tensordot``
        appends the new axis at the end, and axes without a basis (empty entry) are permuted to the end, so after
        the loop the original axis order is restored.
        """
        original_shape = grad.shape
        permuted_shape = original_shape

        do_permute: bool = self.data_format == 'channels_last' and len(original_shape) == 4

        if merge_dims:
            if do_permute:
                permuted_shape = grad.permute(0, 3, 1, 2).shape

            grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))

        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(grad, mat, dims=[[0], [0 if project_type == 'forward' else 1]])
            else:
                # no basis for this axis: just rotate it to the end
                grad = grad.permute([*list(range(1, len(grad.shape))), 0])

        if merge_dims:
            grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1) if do_permute else grad.reshape(original_shape)

        return grad

    @staticmethod
    def get_orthogonal_matrix(mat: torch.Tensor) -> List[torch.Tensor]:
        """Eigen-decompose each pre-conditioner matrix and return its eigenvectors.

        Columns are flipped so eigenvectors come in descending eigenvalue order (``torch.linalg.eigh`` returns them
        ascending). Empty entries (non-preconditioned axes) stay empty. If ``eigh`` raises, it is retried in float64
        and the result is cast back to the matrix dtype.
        """
        matrices: List = []
        for m in mat:
            if len(m) == 0:
                matrices.append([])
                continue

            # the tiny diagonal jitter is added before decomposition (and again on the float64 retry)
            try:
                _, q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
            except Exception:  # pragma: no cover
                _, q = torch.linalg.eigh(
                    m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=torch.float64)
                )
                q = q.to(m.dtype)

            q = torch.flip(q, dims=[1])

            matrices.append(q)

        return matrices

    def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, merge_dims: bool = False):
        """Compute the eigen-bases of the pre-conditioner using one round of power iteration.

        Each basis is first re-sorted by its estimated eigenvalues (``diag(Q^T GG Q)``, descending).
        ``state['exp_avg_sq']``, which is accumulated in the rotated basis, is permuted along the same axis so it
        stays aligned with the new basis, and the permuted tensor is written back to ``state``.
        """
        original_shape = state['exp_avg_sq'].shape
        permuted_shape = original_shape
        if self.data_format == 'channels_last' and len(original_shape) == 4:
            permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape

        exp_avg_sq = state['exp_avg_sq']
        if merge_dims:
            exp_avg_sq = exp_avg_sq.reshape(merge_small_dims(exp_avg_sq.size(), max_precondition_dim))

        matrices = []
        for ind, (m, o) in enumerate(zip(state['GG'], state['Q'])):
            if len(m) == 0:
                matrices.append([])
                continue

            est_eig = torch.diag(o.T @ m @ o)
            sort_idx = torch.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)

            power_iter = m @ o[:, sort_idx]

            # Compute QR decomposition
            # We cast to float32 because:
            #  - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
            #  - the correctness / numerical stability of the Q orthogonality is important for the stability
            #    of the optimizer
            q, _ = torch.linalg.qr(power_iter.to(torch.float32))
            q = q.to(power_iter.dtype)

            matrices.append(q)

        if merge_dims:
            if self.data_format == 'channels_last' and len(original_shape) == 4:
                exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                exp_avg_sq = exp_avg_sq.reshape(original_shape)

        state['exp_avg_sq'] = exp_avg_sq

        return matrices

    @staticmethod
    def init_pre_conditioner(
        grad,
        state,
        precondition_frequency: int = 10,
        shampoo_beta: float = 0.95,
        max_precondition_dim: int = 10000,
        precondition_1d: bool = False,
        merge_dims: bool = False,
    ) -> None:
        """Allocate zeroed pre-conditioner statistics ``state['GG']``, one square matrix per axis.

        Axes longer than ``max_precondition_dim`` (and 1-D tensors unless ``precondition_1d``) get an empty list,
        meaning they are not preconditioned. ``state['Q']`` is reset to None so that ``update_pre_conditioner``
        computes it by full eigendecomposition.
        """
        state['GG'] = []
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precondition_dim:
                state['GG'].append([])
            else:
                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
        else:
            if merge_dims:
                grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))

            for sh in grad.shape:
                if sh > max_precondition_dim:
                    state['GG'].append([])
                else:
                    state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))

        state['Q'] = None
        state['precondition_frequency'] = precondition_frequency
        state['shampoo_beta'] = shampoo_beta

    def update_pre_conditioner(
        self,
        grad,
        state,
        step: int,
        max_precondition_dim: int = 10000,
        precondition_1d: bool = False,
        merge_dims: bool = False,
    ) -> None:
        """Fold the current gradient into ``state['GG']`` and refresh the eigen-bases ``state['Q']``.

        Each ``GG`` matrix is an EMA (weight ``1 - shampoo_beta``) of the gradient's outer product over one axis,
        obtained by contracting every other axis. The bases are computed by full eigendecomposition the first time,
        then refreshed with one QR power-iteration step whenever ``step`` is a positive multiple of
        ``precondition_frequency``.
        """
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precondition_dim:
                state['GG'][0].lerp_(
                    (grad.unsqueeze(1) @ grad.unsqueeze(0)).to(state['GG'][0].dtype),
                    weight=1.0 - state['shampoo_beta'],
                )
        else:
            if merge_dims:
                grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))

            for idx, dim in enumerate(grad.shape):
                if dim <= max_precondition_dim:
                    # contract every axis except ``idx`` -> (dim, dim) outer product
                    outer_product = torch.tensordot(
                        grad,
                        grad,
                        dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
                    )

                    state['GG'][idx].lerp_(
                        outer_product.to(state['GG'][idx].dtype), weight=1.0 - state['shampoo_beta']
                    )

        if state['Q'] is None:
            state['Q'] = self.get_orthogonal_matrix(state['GG'])

        if step > 0 and step % state['precondition_frequency'] == 0:
            state['Q'] = self.get_orthogonal_matrix_qr(state, max_precondition_dim, merge_dims)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            # the first step only collects pre-conditioner statistics (done in init_group); parameters are
            # not updated
            if group['step'] == 1:
                continue

            beta1, beta2 = group['betas']

            step_size: float = group['lr']
            if group['correct_bias']:
                bias_correction1: float = self.debias(beta1, group['step'])
                bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

                step_size *= bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_projected = self.project(
                    grad, state, merge_dims=group['merge_dims'], max_precondition_dim=group['max_precondition_dim']
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # exp_avg accumulates the raw gradient and is projected on use below;
                # exp_avg_sq accumulates the squared *projected* gradient
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                exp_avg_projected = self.project(
                    exp_avg, state, merge_dims=group['merge_dims'], max_precondition_dim=group['max_precondition_dim']
                )

                # Adam step in the rotated basis, then rotate back to parameter space
                norm_grad = self.project(
                    exp_avg_projected / de_nom,
                    state,
                    merge_dims=group['merge_dims'],
                    max_precondition_dim=group['max_precondition_dim'],
                    project_type='backward',
                )

                if group['normalize_gradient']:
                    norm_grad.div_(torch.mean(norm_grad.square()).sqrt_().add_(group['eps']))

                p.add_(norm_grad, alpha=-step_size)

                # decoupled weight decay, applied after the update step
                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                # fold this gradient into the statistics used by the next step
                self.update_pre_conditioner(
                    grad,
                    state,
                    step=group['step'],
                    max_precondition_dim=group['max_precondition_dim'],
                    merge_dims=group['merge_dims'],
                    precondition_1d=group['precondition_1d'],
                )

        return loss

get_orthogonal_matrix_qr(state, max_precondition_dim=10000, merge_dims=False)

Compute the eigen-bases of the pre-conditioner using one round of power iteration.

Source code in pytorch_optimizer/optimizer/soap.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, merge_dims: bool = False):
    """Compute the eigen-bases of the pre-conditioner using one round of power iteration.

    For every pre-conditioner ``state['GG'][i]`` with current basis ``state['Q'][i]``:

    1. estimate the eigenvalues as ``diag(Q^T @ GG @ Q)`` and sort the basis columns by descending estimate;
    2. take one power-iteration step ``GG @ Q_sorted``;
    3. re-orthogonalize the result with a QR decomposition.

    ``state['exp_avg_sq']`` is re-ordered along the matching dimension so that it stays aligned with the
    re-ordered basis, and is written back into ``state``.

    Args:
        state (dict): parameter state holding ``'GG'``, ``'Q'`` and ``'exp_avg_sq'``.
        max_precondition_dim (int): maximum dimension forwarded to ``merge_small_dims`` when ``merge_dims`` is set.
        merge_dims (bool): whether ``exp_avg_sq`` is viewed with its small dimensions merged before re-ordering.
            NOTE(review): this must match how the pre-conditioners in ``state['GG']`` were built.

    Returns:
        list: one orthogonal matrix per pre-conditioner, or ``[]`` for dimensions that are not pre-conditioned.
    """
    exp_avg_sq = state['exp_avg_sq']
    original_shape = exp_avg_sq.shape
    is_channels_last_4d: bool = self.data_format == 'channels_last' and len(original_shape) == 4

    permuted_shape = original_shape
    if merge_dims:
        # Move channels to dim 1 *before* merging. The restore step below (reshape to ``permuted_shape``
        # followed by ``permute(0, 2, 3, 1)``) is only the inverse of "permute, then merge"; merging the
        # un-permuted tensor would scramble ``exp_avg_sq`` on the way back.
        if is_channels_last_4d:
            exp_avg_sq = exp_avg_sq.permute(0, 3, 1, 2)
            permuted_shape = exp_avg_sq.shape
        exp_avg_sq = exp_avg_sq.reshape(merge_small_dims(exp_avg_sq.size(), max_precondition_dim))

    matrices = []
    for ind, (m, o) in enumerate(zip(state['GG'], state['Q'])):
        if len(m) == 0:
            # this dimension is not pre-conditioned
            matrices.append([])
            continue

        est_eig = torch.diag(o.T @ m @ o)
        sort_idx = torch.argsort(est_eig, descending=True)
        # keep the second moment aligned with the re-ordered eigen-basis
        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)

        power_iter = m @ o[:, sort_idx]

        # Compute QR decomposition
        # We cast to float32 because:
        #  - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
        #  - the correctness / numerical stability of the Q orthogonality is important for the stability
        #    of the optimizer
        q, _ = torch.linalg.qr(power_iter.to(torch.float32))
        matrices.append(q.to(power_iter.dtype))

    if merge_dims:
        if is_channels_last_4d:
            exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
        else:
            exp_avg_sq = exp_avg_sq.reshape(original_shape)

    state['exp_avg_sq'] = exp_avg_sq

    return matrices

SophiaH

Bases: BaseOptimizer

Second-order Clipped Stochastic Optimization.

Requires loss.backward(create_graph=True) in order to calculate hessians.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.06
betas Betas

Coefficients used for computing running averages of the gradient and of the Hessian estimate.

(0.96, 0.99)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Fix weight decay.

False
p float

Clipping bound: each element of the pre-conditioned update is clamped to [-p, p] before being scaled by the learning rate.

0.01
update_period int

The Hessian estimate is recomputed, and folded into its running average, every update_period steps.

10
num_samples int

Times to sample z for the approximation of the Hessian trace.

1
hessian_distribution HutchinsonG

Distribution ('gaussian' or 'rademacher') used to sample the random vectors z for the Hutchinson Hessian estimate.

'gaussian'
eps float

Term added to the denominator to improve numerical stability.

1e-12
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sophia.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class SophiaH(BaseOptimizer):
    r"""Second-order Clipped Stochastic Optimization.

    Requires `loss.backward(create_graph=True)` in order to calculate hessians.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of the Hessian
            estimate.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Fix weight decay.
        p (float): Clipping bound; each element of the pre-conditioned update is clamped to [-p, p].
        update_period (int): The Hessian estimate is recomputed, and folded into its running average, every
            `update_period` steps.
        num_samples (int): Times to sample z for the approximation of the Hessian trace.
        hessian_distribution (HutchinsonG): Distribution ('gaussian' or 'rademacher') of the random vectors z.
        eps (float): Lower bound applied to the Hessian running average in the denominator.
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 6e-2,
        betas: Betas = (0.96, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        p: float = 1e-2,
        update_period: int = 10,
        num_samples: int = 1,
        hessian_distribution: HutchinsonG = 'gaussian',
        eps: float = 1e-12,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(p, 'p (gradient clip)')
        self.validate_step(update_period, 'update_period')
        self.validate_positive(num_samples, 'num_samples')
        self.validate_options(hessian_distribution, 'hessian_distribution', ['gaussian', 'rademacher'])
        self.validate_non_negative(eps, 'eps')

        self.update_period = update_period
        self.num_samples = num_samples
        self.distribution = hessian_distribution
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'p': p,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SophiaH'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['momentum'] = torch.zeros_like(grad)
                state['hessian_moment'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: Closure = None, hessian: Optional[List[torch.Tensor]] = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): re-evaluates the model and returns the loss.
            hessian (Optional[List[torch.Tensor]]): externally computed Hessian estimates. When given, they are
                passed to `set_hessian` instead of computing the Hutchinson estimate, and the running Hessian
                average is updated in this step.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Index of the step being taken now (`group['step']` is incremented to this value below). The same
        # number decides both when the Hutchinson estimate is computed and when it is folded into
        # `hessian_moment`, so a fresh estimate is consumed in the very call that computed it instead of
        # `update_period - 1` steps later.
        step: int = self.param_groups[0].get('step', 0) + 1

        if hessian is not None:
            self.set_hessian(self.param_groups, self.state, hessian)
        elif step % self.update_period == 0:
            self.zero_hessian(self.param_groups, self.state)
            self.compute_hutchinson_hessian(
                param_groups=self.param_groups,
                state=self.state,
                num_samples=self.num_samples,
                distribution=self.distribution,
            )

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                momentum, hessian_moment = state['momentum'], state['hessian_moment']
                momentum.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                # only refresh the running Hessian average on the steps where a new estimate was produced
                if 'hessian' in state and (group['step'] % self.update_period == 0 or hessian is not None):
                    hessian_moment.mul_(beta2).add_(state['hessian'], alpha=1.0 - beta2)

                # pre-conditioned momentum, element-wise clipped to [-p, p]
                update = (momentum / torch.clip(hessian_moment, min=group['eps'])).clamp_(-group['p'], group['p'])

                p.add_(update, alpha=-group['lr'])

        return loss

SPAM

Bases: BaseOptimizer

Spike-Aware Adam with Momentum Reset for Stable LLM Training.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and of its square.

(0.9, 0.999)
density float

Fraction of the entries of each 2D parameter (e.g., Linear weights) that are updated; 1.0 updates every entry. Only used for 2D parameters.

1.0
weight_decay float

Weight decay (L2 penalty).

0.0
warmup_epoch int

Length, in optimizer steps, of the warm-up applied to the update scale after each mask update. Defaults to 50.

50
threshold int

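Spike threshold: gradient entries with grad² > threshold · exp_avg_sq are clipped to ±sqrt(threshold · exp_avg_sq); 0 disables spike clipping. Defaults to 5000.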

5000
grad_accu_steps int

Number of steps after each moment reset during which spike clipping is not applied. Defaults to 20.

20
update_proj_gap int

Number of steps between resets of the optimizer moments and updates of the random masks.

500
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/spam.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
class SPAM(BaseOptimizer):
    r"""Spike-Aware Adam with Momentum Reset for Stable LLM Training.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of its square.
        density (float): Fraction of the entries of each 2D parameter (e.g., Linear) that are updated.
        weight_decay (float): Weight decay (L2 penalty).
        warmup_epoch (int): Number of steps of update-scale warm-up after each mask update. Defaults to 50.
        threshold (int): Spike threshold; entries with grad^2 > threshold * exp_avg_sq are clipped. 0 disables.
            Defaults to 5000.
        grad_accu_steps (int): Steps after each moment reset before spike clipping applies. Defaults to 20.
        update_proj_gap (int): Number of steps between moment resets and mask updates.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        density: float = 1.0,
        weight_decay: float = 0.0,
        warmup_epoch: int = 50,
        threshold: int = 5000,
        grad_accu_steps: int = 20,
        update_proj_gap: int = 500,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(warmup_epoch, 'warmup_epoch')
        self.validate_non_negative(density, 'density')
        self.validate_non_negative(threshold, 'threshold')
        self.validate_non_negative(grad_accu_steps, 'grad_accu_steps')
        self.validate_positive(update_proj_gap, 'update_proj_gap')
        self.validate_non_negative(eps, 'eps')

        self.density = density
        self.warmup_epoch = warmup_epoch
        self.threshold = threshold
        self.grad_accu_steps = grad_accu_steps
        self.update_proj_gap = update_proj_gap
        self.maximize = maximize

        defaults: Defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}

        super().__init__(params, defaults)

        self.warmup = CosineDecay(0.99, self.warmup_epoch)

        self.init_masks()

        self.state['total_step'] = 0
        self.state['current_step'] = self.warmup_epoch + 1

    @staticmethod
    def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device: torch.device) -> torch.Tensor:
        r"""Create an (m x n) boolean tensor with `density` fraction of True entries.

        :param m: int. number of rows.
        :param n: int. number of columns.
        :param density: float. fraction of True entries. 1.0 means all True.
        :param device: torch.device. device.
        """
        total_elements: int = m * n
        non_zero_count: int = int(density * total_elements)

        tensor = torch.zeros(total_elements, dtype=torch.bool, device=device)

        if non_zero_count > 0:
            tensor[torch.randperm(total_elements, device=device)[:non_zero_count]] = True

        return tensor.view(m, n)

    def update_mask_random(self, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
        r"""Update a random mask.

        Create a new random mask with the same density and carry over the optimizer moments for the entries
        selected by both the old and the new mask; all other moments start at zero. If the parameter has no
        moments yet, only the new mask is returned and `step` creates the moments lazily.

        :param p: torch.Tensor. parameter to which the mask is applied.
        :param old_mask: torch.Tensor. previous binary mask.
        """
        new_mask: torch.Tensor = torch.rand_like(p) < self.density

        state = self.state[p]
        if 'exp_avg' not in state:
            # e.g. a masked parameter that never received a gradient: nothing to carry over
            return new_mask

        exp_avg = torch.zeros_like(p[new_mask])
        exp_avg_sq = torch.zeros_like(p[new_mask])

        intersection_mask = new_mask & old_mask
        new_intersection_indices = intersection_mask[new_mask]
        old_intersection_indices = intersection_mask[old_mask]

        exp_avg[new_intersection_indices] = state['exp_avg'][old_intersection_indices]
        exp_avg_sq[new_intersection_indices] = state['exp_avg_sq'][old_intersection_indices]

        state['exp_avg'] = exp_avg
        state['exp_avg_sq'] = exp_avg_sq

        return new_mask

    def update_masks(self) -> None:
        r"""Replace the mask of every 2D parameter that has one with a new random mask.

        Moments for entries shared by the old and the new mask are carried over; the new mask is also attached
        to the parameter as `p.mask`.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if p.dim() == 2 and 'mask' in state:
                    state['mask'] = self.update_mask_random(p, state['mask'])
                    p.mask = state['mask']

    def init_masks(self) -> None:
        r"""Initialize random masks for each parameter group that has 'density'."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if p.dim() == 2 and 'mask' not in state:
                    state['mask'] = self.initialize_random_rank_boolean_tensor(
                        m=p.shape[0],
                        n=p.shape[1],
                        density=self.density,
                        device=p.device,
                    )

    def __str__(self) -> str:
        return 'SPAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        r"""Perform a single optimization step.

        Args:
            closure (Closure): re-evaluates the model and returns the loss.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        scale_factor: float = 1.0 - self.warmup.get_death_rate(self.state['current_step'])

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if 'mask' in state:
                    grad = grad[state['mask']]

                # moments are reset every `update_proj_gap` steps
                if ('exp_avg' not in state) or (self.state['total_step'] + 1) % self.update_proj_gap == 0:
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                if self.threshold != 0:
                    current_step: int = self.state['total_step'] + 1
                    if current_step >= self.grad_accu_steps and (
                        self.update_proj_gap == 0 or current_step % self.update_proj_gap >= self.grad_accu_steps
                    ):
                        # Spike clipping. Boolean-mask indexing (`grad[mask]`) returns a copy, so in-place ops on
                        # it never reach `grad`; build the clipped gradient out-of-place instead.
                        spike_mask = grad.pow(2) > (self.threshold * exp_avg_sq)
                        grad = torch.where(spike_mask, grad.sign() * torch.sqrt(exp_avg_sq * self.threshold), grad)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if 'mask' in state:
                    # scatter the masked update back into a full-size tensor
                    grad_full = torch.zeros_like(p.grad)
                    grad_full[state['mask']] = exp_avg / de_nom
                    p.add_(grad_full, alpha=-step_size * scale_factor)
                else:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size * scale_factor)

                if 'mask' in state:
                    if group['weight_decay'] > 0.0:
                        # `p[mask]` is a copy: decay the gathered entries, then write them back into `p`.
                        masked_p = p[state['mask']]
                        self.apply_weight_decay(
                            masked_p,
                            grad=None,
                            lr=group['lr'],
                            weight_decay=group['weight_decay'],
                            weight_decouple=True,
                            fixed_decay=False,
                        )
                        p[state['mask']] = masked_p
                else:
                    self.apply_weight_decay(
                        p,
                        grad=None,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=True,
                        fixed_decay=False,
                    )

        self.state['total_step'] += 1
        self.state['current_step'] += 1

        if (self.state['total_step'] != 0) and (self.state['total_step'] + 1) % self.update_proj_gap == 0:
            self.update_masks()
            self.state['current_step'] = 0
            self.warmup = CosineDecay(0.99, self.warmup_epoch)

        return loss

init_masks()

Initialize random masks for each parameter group that has 'density'.

Source code in pytorch_optimizer/optimizer/spam.py
175
176
177
178
179
180
181
182
183
184
185
186
def init_masks(self) -> None:
    r"""Attach a random boolean mask to every 2D parameter that does not have one yet.

    The mask has the parameter's shape and roughly `self.density` fraction of True entries; it is stored as
    `state['mask']`. Parameters of any other rank are left untouched.
    """
    every_param = (param for group in self.param_groups for param in group['params'])
    for param in every_param:
        param_state = self.state[param]
        if param.dim() != 2 or 'mask' in param_state:
            continue

        rows, cols = param.shape
        param_state['mask'] = self.initialize_random_rank_boolean_tensor(
            m=rows,
            n=cols,
            density=self.density,
            device=param.device,
        )

initialize_random_rank_boolean_tensor(m, n, density, device) staticmethod

Create an (m x n) boolean tensor with density fraction of True entries.

:param m: int. number of rows. :param n: int. number of columns. :param density: float. fraction of True entries. 1.0 means all True. :param device: torch.device. device.

Source code in pytorch_optimizer/optimizer/spam.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
@staticmethod
def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device: torch.device) -> torch.Tensor:
    r"""Create an (m x n) boolean tensor with `density` fraction of True entries.

    :param m: int. number of rows.
    :param n: int. number of columns.
    :param density: float. fraction of True entries. 1.0 means all True.
    :param device: torch.device. device.
    """
    total_elements: int = m * n
    non_zero_count: int = int(density * total_elements)

    tensor = torch.zeros(total_elements, dtype=torch.bool, device=device)

    if non_zero_count > 0:
        tensor[torch.randperm(total_elements, device=device)[:non_zero_count]] = True

    return tensor.view(m, n)

update_mask_random(p, old_mask)

Update a random mask.

Create a new random mask with the same density and carry over the optimizer moments (EMA states) for the entries selected by both old_mask and the new mask; all other moments start at zero.

:param p: torch.Tensor. parameter to which the mask is applied. :param old_mask: torch.Tensor. previous binary mask.

Source code in pytorch_optimizer/optimizer/spam.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def update_mask_random(self, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
    r"""Update a random mask.

    Create a new random mask with the same density and carry over the optimizer moments for the entries selected
    by both the old and the new mask; all other moments start at zero. If the parameter has no moments yet
    (e.g. it never received a gradient), only the new mask is returned and the moments are left to be created
    lazily.

    :param p: torch.Tensor. parameter to which the mask is applied.
    :param old_mask: torch.Tensor. previous binary mask.
    """
    new_mask: torch.Tensor = torch.rand_like(p) < self.density

    state = self.state[p]
    if 'exp_avg' not in state:
        # nothing to carry over; previously this raised KeyError
        return new_mask

    exp_avg = torch.zeros_like(p[new_mask])
    exp_avg_sq = torch.zeros_like(p[new_mask])

    # positions (within each mask's compacted vector) of the entries both masks select
    intersection_mask = new_mask & old_mask
    new_intersection_indices = intersection_mask[new_mask]
    old_intersection_indices = intersection_mask[old_mask]

    exp_avg[new_intersection_indices] = state['exp_avg'][old_intersection_indices]
    exp_avg_sq[new_intersection_indices] = state['exp_avg_sq'][old_intersection_indices]

    state['exp_avg'] = exp_avg
    state['exp_avg_sq'] = exp_avg_sq

    return new_mask

update_masks()

Update masks in each parameter group that has 'density'.

The new mask is selected randomly; optimizer moments for entries shared with the old mask are carried over.

Source code in pytorch_optimizer/optimizer/spam.py
163
164
165
166
167
168
169
170
171
172
173
def update_masks(self) -> None:
    r"""Replace the mask of every 2D parameter that already has one with a fresh random mask.

    The new mask comes from `update_mask_random`. It is stored in `state['mask']` and also attached to the
    parameter as `p.mask`. Parameters without a mask, or not 2D, are skipped.
    """
    for group in self.param_groups:
        for param in group['params']:
            param_state = self.state[param]
            if param.dim() != 2 or 'mask' not in param_state:
                continue

            refreshed = self.update_mask_random(param, param_state['mask'])
            param_state['mask'] = refreshed
            param.mask = refreshed

SpectralSphere

Bases: BaseOptimizer

Controlled LLM Training on Spectral Sphere.

This optimizer constrains weight matrices to lie on a spectral sphere of fixed radius R, where ||W||_2 = R. The optimization proceeds by:

  1. Power iteration to compute spectral norm sigma and top singular vectors (u, v)
  2. Retraction to spectral sphere: W ← (R / sigma) * W
  3. Form Θ = u @ v^T
  4. Solve for Lagrange multiplier lambda: <Θ, msign(M + lambdaΘ)> = 0
  5. Compute update direction: Φ = msign(M + lambdaΘ)
  6. Update: W ← W - lr * Φ

The key insight is that the retraction step at the end of iteration t is equivalent to the retraction at the beginning of iteration t+1. This allows us to unify the power iteration for both retraction and Theta computation in a single efficient step.

References
  • Spectral MuP: Spectral Control of Feature Learning
  • Modular Duality in Deep Learning. arXiv:2410.21265 (2024).

Parameters:

Name Type Description Default
params ParamsT

The parameters to be optimized by SpectralSphere.

required
lr float

Learning rate.

0.0003
momentum float

The momentum used by the internal SGD.

0.9
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

The optimizer uses decoupled weight decay as in AdamW.

True
nesterov bool

Whether to use nesterov momentum.

True
power_iteration_steps int

Number of power iteration steps for spectral norm computation.

10
msign_steps int

Number of Newton-Schulz iterations for msign (uses Polar-Express).

5
solver_tolerance_f float

Function value tolerance for solver.

1e-08
solver_max_iterations int

Maximum iterations for solver.

100
maximize bool

Maximize the objective with respect to the params, instead of minimizing.

False
Example

from pytorch_optimizer import SpectralSphere

hidden_weights = [p for p in model.body.parameters() if p.ndim >= 2]

param_groups = [ dict(params=hidden_weights, lr=0.02, weight_decay=0.01), ]

optimizer = SpectralSphere(param_groups) ...

Source code in pytorch_optimizer/optimizer/sso.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
class SpectralSphere(BaseOptimizer):
    """Controlled LLM Training on Spectral Sphere.

    This optimizer constrains weight matrices to lie on a spectral sphere of fixed radius R,
    where ||W||_2 = R. The optimization proceeds by:

    1. Power iteration to compute spectral norm sigma and top singular vectors (u, v)
    2. Retraction to spectral sphere: W ← (R / sigma) * W
    3. Form Θ = u @ v^T
    4. Solve for Lagrange multiplier lambda: <Θ, msign(M + lambdaΘ)> = 0
    5. Compute update direction: Φ = msign(M + lambdaΘ)
    6. Update: W ← W - lr * Φ

    The key insight is that the retraction step at the end of iteration t is equivalent to
    the retraction at the beginning of iteration t+1. This allows us to unify the power
    iteration for both retraction and Theta computation in a single efficient step.

    Only 2D parameters are supported: ``step`` raises ``ValueError`` for any parameter that has a
    gradient and whose ``dim() != 2``.

    References:
        - Spectral MuP: Spectral Control of Feature Learning
        - Modular Duality in Deep Learning. arXiv:2410.21265 (2024).

    Args:
        params (ParamsT): The parameters to be optimized by SpectralSphere. Every parameter that receives a
            gradient must be 2D.
        lr (float): Learning rate.
        momentum (float): The momentum used by the internal SGD.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): The optimizer uses decoupled weight decay as in AdamW.
        nesterov (bool): Whether to use nesterov momentum.
        power_iteration_steps (int): Number of power iteration steps for spectral norm computation.
        msign_steps (int): Number of Newton-Schulz iterations for msign (uses Polar-Express).
        solver_tolerance_f (float): Function value tolerance for solver. Must be non-negative.
        solver_max_iterations (int): Maximum iterations for solver. Must be positive.
        maximize (bool): Maximize the objective with respect to the params, instead of minimizing.

    Example:
        from pytorch_optimizer import SpectralSphere

        hidden_weights = [p for p in model.body.parameters() if p.ndim == 2]

        param_groups = [
            dict(params=hidden_weights, lr=0.02, weight_decay=0.01),
        ]

        optimizer = SpectralSphere(param_groups)
        ...

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum: float = 0.9,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        nesterov: bool = True,
        power_iteration_steps: int = 10,
        msign_steps: int = 5,
        solver_tolerance_f: float = 1e-8,
        solver_max_iterations: int = 100,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_positive(power_iteration_steps, 'power_iteration_steps')
        self.validate_positive(msign_steps, 'msign_steps')
        # The Lagrange-multiplier solver settings were previously accepted unchecked.
        self.validate_non_negative(solver_tolerance_f, 'solver_tolerance_f')
        self.validate_positive(solver_max_iterations, 'solver_max_iterations')

        # Solver / iteration settings are optimizer-wide (not per param group).
        self.power_iteration_steps = power_iteration_steps
        self.msign_steps = msign_steps
        self.solver_tolerance_f = solver_tolerance_f
        self.solver_max_iterations = solver_max_iterations

        self.maximize = maximize

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'nesterov': nesterov,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SpectralSphere'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate the group's parameters and lazily create their momentum buffers.

        Raises:
            ValueError: if a parameter with a gradient is not 2D.
            NoSparseGradientError: if a gradient is sparse.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            if p.dim() != 2:
                raise ValueError(f'{self} only supports 2D parameters')

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                # buf <- momentum * buf + (1 - momentum) * grad
                buf = state['momentum_buffer']
                buf.lerp_(grad, weight=1.0 - group['momentum'])

                # Nesterov look-ahead computed in place on p.grad:
                # grad <- (1 - momentum) * grad + momentum * buf
                update = grad.lerp_(buf, weight=group['momentum']) if group['nesterov'] else buf

                update = compute_spectral_ball_update(
                    p,
                    momentum=update,
                    power_iteration_steps=self.power_iteration_steps,
                    msign_steps=self.msign_steps,
                    solver_tolerance_f=self.solver_tolerance_f,
                    solver_max_iterations=self.solver_max_iterations,
                )

                p.add_(update, alpha=-group['lr'])

        return loss

SPlus

Bases: BaseOptimizer

A Stable Whitening Optimizer for Efficient Neural Network Training.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.1
betas Betas

Coefficients for the momentum (beta1) and for the running averages of the gradient covariance factors G·Gᵀ and Gᵀ·G (beta2).

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

Whether the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Whether to fix weight decay.

False
ema_rate float

Exponential moving average decay rate.

0.999
inverse_steps int

Interval, in steps, between refreshes of the eigenbases of the gradient covariance factors (a refresh also happens at step 1).

100
nonstandard_constant float

Learning-rate scale used for parameters that are not 2D or that have a dimension larger than max_dim.

0.001
max_dim int

Size limit for a matrix dimension: a preconditioning factor is only kept for dimensions smaller than max_dim.

10000
eps float

Ridge added to the diagonal of the gradient covariance factors before the eigendecomposition, for numerical stability.

1e-30
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/splus.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
class SPlus(BaseOptimizer):
    """A Stable Whitening Optimizer for Efficient Neural Network Training.

    For 2D parameters, SPlus keeps running averages of the left/right gradient covariance factors
    (``G @ G.T`` and ``G.T @ G``). It periodically recomputes their eigenbases, takes ``sign`` of the
    momentum expressed in that eigenbasis, and rotates the result back. Other parameters use
    ``sign(momentum)`` directly.

    The step size is ``lr * 2 / (rows + cols)`` for 2D parameters whose dimensions do not exceed
    ``max_dim``, and ``lr * nonstandard_constant`` otherwise.

    An EMA of the parameters is tracked. ``eval()`` swaps it in (bias-corrected) and ``train()``
    restores the training weights.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients for the momentum (beta1) and for the running averages of the gradient
            covariance factors (beta2).
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Whether to fix weight decay.
        ema_rate (float): Decay rate of the parameter EMA used in evaluation mode.
        inverse_steps (int): Interval, in steps, between eigendecomposition refreshes (also done at step 1).
        nonstandard_constant (float): Learning-rate scale for parameters that are not 2D or that have a dimension
            larger than ``max_dim``.
        max_dim (int): Preconditioning factors are only kept for dimensions smaller than ``max_dim``.
        eps (float): Ridge added to the covariance factors before the eigendecomposition.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-1,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ema_rate: float = 0.999,
        inverse_steps: int = 100,
        nonstandard_constant: float = 1e-3,
        max_dim: int = 10000,
        eps: float = 1e-30,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(ema_rate, 'ema_rate', 0.0, 1.0)
        self.validate_positive(inverse_steps, 'inverse_steps')
        self.validate_positive(max_dim, 'max_dim')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ema_rate': ema_rate,
            'inverse_steps': inverse_steps,
            'max_dim': max_dim,
            'nonstandard_constant': nonstandard_constant,
            'eps': eps,
            'train_mode': True,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SPlus'

    @torch.no_grad()
    def eval(self):
        """Switch to evaluation mode by swapping each parameter for its bias-corrected EMA.

        The current training weights are stashed in ``state['param_buffer']`` so ``train()`` can restore them.
        Parameters without an EMA yet (never stepped) are left untouched, so calling ``eval()`` before the
        first ``step()`` is safe.
        """
        for group in self.param_groups:
            if not group.get('train_mode'):
                continue

            for p in group['params']:
                state = self.state[p]
                if 'ema' not in state:
                    # No EMA yet (e.g. eval() before the first step, or a param that never got a gradient).
                    continue

                state['param_buffer'] = p.clone()
                # The EMA starts from zeros, so divide out the (1 - rate^t) bias.
                p.copy_(state['ema']).mul_(1.0 / (1.0 - group['ema_rate'] ** group['step']))

            group['train_mode'] = False

    @torch.no_grad()
    def train(self):
        """Switch back to training mode, restoring the weights stashed by ``eval()``."""
        for group in self.param_groups:
            if group.get('train_mode', True):
                continue

            for p in group['params']:
                state = self.state[p]
                if 'param_buffer' in state:
                    p.copy_(state.pop('param_buffer'))

            group['train_mode'] = True

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate the group's parameters and lazily create momentum, EMA and preconditioner state.

        Raises:
            NoSparseGradientError: if a gradient is sparse.
            NoComplexParameterError: if a parameter is complex.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['momentum'] = torch.zeros_like(p)
                state['ema'] = torch.zeros_like(p)
                if len(p.shape) == 2:
                    # One covariance factor / eigenbasis per side; None when that side is too large.
                    state['sides'] = [
                        torch.zeros((d, d), device=p.device, dtype=p.dtype) if d < group['max_dim'] else None
                        for d in p.shape
                    ]
                    state['q_sides'] = [
                        torch.eye(d, device=p.device, dtype=p.dtype) if d < group['max_dim'] else None for d in p.shape
                    ]

    @staticmethod
    def get_scaled_lr(shape: Tuple[int, int], lr: float, nonstandard_constant: float, max_dim: int = 10000) -> float:
        """Return ``lr * 2 / (rows + cols)`` for 2D shapes within ``max_dim``, else ``lr * nonstandard_constant``."""
        scale: float = (
            nonstandard_constant
            if len(shape) != 2 or shape[0] > max_dim or shape[1] > max_dim
            else 2.0 / (shape[0] + shape[1])
        )
        return lr * scale

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                scaled_lr: float = self.get_scaled_lr(
                    p.shape, group['lr'], group['nonstandard_constant'], group['max_dim']
                )

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=scaled_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                m, ema = state['momentum'], state['ema']
                m.lerp_(grad, weight=1.0 - beta1)

                if len(p.shape) == 2:
                    sides, q_sides = state['sides'], state['q_sides']

                    # Rotate the momentum into the current eigenbasis (rebinding m; the state buffer is untouched).
                    m = q_sides[0].T @ m if q_sides[0] is not None else m
                    m = m @ q_sides[1] if q_sides[1] is not None else m

                    if sides[0] is not None:
                        torch.lerp(sides[0], grad @ grad.T, weight=1.0 - beta2, out=sides[0])

                    if sides[1] is not None:
                        torch.lerp(sides[1], grad.T @ grad, weight=1.0 - beta2, out=sides[1])

                    update = torch.sign(m)

                    # Rotate the sign update back to the parameter basis.
                    if q_sides[0] is not None:
                        update = q_sides[0] @ update

                    if q_sides[1] is not None:
                        update = update @ q_sides[1].T

                    # Refresh the eigenbases after the update was computed; they take effect next step.
                    if group['step'] == 1 or group['step'] % group['inverse_steps'] == 0:
                        if sides[0] is not None:
                            _, eig_vecs = torch.linalg.eigh(
                                sides[0].float() + torch.eye(sides[0].shape[0], device=p.device).mul_(group['eps'])
                            )
                            state['q_sides'][0] = eig_vecs.to(sides[0].dtype)
                        if sides[1] is not None:
                            _, eig_vecs = torch.linalg.eigh(
                                sides[1].float() + torch.eye(sides[1].shape[0], device=p.device).mul_(group['eps'])
                            )
                            state['q_sides'][1] = eig_vecs.to(sides[1].dtype)
                else:
                    update = torch.sign(m)

                p.add_(update, alpha=-scaled_lr)

                ema.lerp_(p, weight=1.0 - group['ema_rate'])

        return loss

SRMM

Bases: BaseOptimizer

Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
beta float

Adaptivity weight.

0.5
memory_length Optional[int]

Internal memory length for moving average. None for no refreshing.

100
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/srmm.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class SRMM(BaseOptimizer):
    """Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates.

    On each step the group computes the weight ``w_t = (k + 1) ** -beta``.
    Here ``k = step % memory_length``, so the averages restart every ``memory_length`` steps.
    When ``memory_length`` is None, ``k`` is always 0 and therefore ``w_t == 1``.

    Every parameter then mixes its gradient and its current value into two running averages
    with that weight. The parameter is overwritten with ``avg_param - lr * avg_grad``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        beta (float): Adaptivity weight.
        memory_length (Optional[int]): Internal memory length for moving average. None for no refreshing.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 0.01,
        beta: float = 0.5,
        memory_length: Optional[int] = 100,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        self.maximize = maximize

        super().__init__(params, {'lr': lr, 'beta': beta, 'memory_length': memory_length})

        # Snapshot of each group's initial learning rate.
        self.base_lrs: List[float] = [pg['lr'] for pg in self.param_groups]

    def __str__(self) -> str:
        return 'SRMM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate the group's parameters and lazily allocate the two moving-average buffers."""
        group.setdefault('step', 0)

        for param in (q for q in group['params'] if q.grad is not None):
            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            param_state = self.state[param]
            if not param_state:
                param_state['mov_avg_grad'] = torch.zeros_like(param.grad)
                param_state['mov_avg_param'] = torch.zeros_like(param.grad)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            period = 1 if group['memory_length'] is None else group['memory_length']
            w_t: float = (group['step'] % period + 1) ** -group['beta']
            keep: float = 1.0 - w_t
            learning_rate = group['lr']

            for param in group['params']:
                g = param.grad
                if g is None:
                    continue

                self.maximize_gradient(g, maximize=self.maximize)

                param_state = self.state[param]
                avg_grad = param_state['mov_avg_grad']
                avg_param = param_state['mov_avg_param']

                avg_grad.mul_(keep).add_(g, alpha=w_t)
                # Average the parameter, then take the gradient step on the averaged point.
                avg_param.mul_(keep).add_(param, alpha=w_t).add_(avg_grad, alpha=-learning_rate)

                param.copy_(avg_param)

        return loss

StableAdamW

Bases: BaseOptimizer

Stable and low-precision training for large-scale vision-language models.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and of the squared gradient.

(0.9, 0.99)
kahan_sum bool

Enables Kahan summation for more accurate parameter updates when training in low precision (float16 or bfloat16).

True
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

Decoupled weight decay.

True
eps float

Term added to the denominator to improve numerical stability.

1e-08
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

Maximize the objective with respect to the parameters, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamw.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
class StableAdamW(BaseOptimizer):
    """Stable and low-precision training for large-scale vision-language models.

    AdamW with per-tensor update clipping. The learning rate of each tensor is divided by the RMS
    statistic returned by ``get_stable_adamw_rms``, computed from the gradient and the *updated*
    second-moment estimate.

    The foreach (multi-tensor) path and the per-parameter path implement the same update.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and of the squared gradient.
        kahan_sum (bool): Enables Kahan summation for more accurate parameter updates when training in low precision
            (float16 or bfloat16).
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Decoupled weight decay. When False, the L2 penalty is added to the gradient before
            the moment estimates are updated.
        eps (float): Term added to the denominator to improve numerical stability.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): Maximize the objective with respect to the parameters, instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.99),
        kahan_sum: bool = True,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        eps: float = 1e-8,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.foreach = foreach
        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'kahan_sum': kahan_sum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
            'foreach': foreach,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'StableAdamW'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Validate gradients and lazily create moment buffers (and a Kahan buffer for fp16/bf16 params)."""
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                # Kahan compensation is only kept for low-precision params when enabled.
                state['kahan_comp'] = (
                    torch.zeros_like(p)
                    if (group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16})
                    else None
                )

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        if group.get('foreach') is False:
            return False

        return self.can_use_foreach(group, group.get('foreach'))

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        exp_avg_sqs: List[torch.Tensor],
        kahan_comps: List[torch.Tensor],
    ) -> None:
        """Multi-tensor update; mirrors ``_step_per_param`` step for step."""
        beta1, beta2 = group['betas']
        eps: float = group['eps']
        lr: float = group['lr']
        weight_decay: float = group['weight_decay']

        beta1_comp: float = 1.0 - self.debias_beta(beta1, group['step'])
        beta2_hat: float = self.debias_beta(beta2, group['step'])

        eps_p2: float = math.pow(eps, 2)

        if self.maximize:
            torch._foreach_neg_(grads)

        if weight_decay != 0.0 and not group['weight_decouple']:
            # Coupled L2 penalty: it must enter the gradient before the moment estimates see it.
            torch._foreach_add_(grads, params, alpha=weight_decay)

        torch._foreach_lerp_(exp_avgs, grads, weight=beta1_comp)

        torch._foreach_mul_(exp_avg_sqs, beta2_hat)
        torch._foreach_addcmul_(exp_avg_sqs, grads, grads, value=1.0 - beta2_hat)

        # The RMS clipping uses the *updated* second moment (as in `_step_per_param`); computing it from the
        # previous buffer made the first step nearly vanish because exp_avg_sq starts at zero.
        step_sizes: List[float] = [
            -lr / self.get_stable_adamw_rms(grad, exp_avg_sq, eps=eps_p2)
            for grad, exp_avg_sq in zip(grads, exp_avg_sqs)
        ]

        if weight_decay != 0.0 and group['weight_decouple']:
            # p <- p * (1 - weight_decay * lr / rms); step_size already carries the minus sign.
            torch._foreach_mul_(params, [1.0 + weight_decay * step_size for step_size in step_sizes])

        de_noms = torch._foreach_sqrt(exp_avg_sqs)
        torch._foreach_add_(de_noms, eps)

        if group['kahan_sum'] and params[0].dtype in (torch.float16, torch.bfloat16):
            torch._foreach_addcdiv_(kahan_comps, exp_avgs, de_noms, step_sizes)

            # Reuse grads as scratch space holding the pre-update params.
            torch._foreach_copy_(grads, params)
            torch._foreach_add_(params, kahan_comps)

            # kahan_comp <- kahan_comp + (old_p - new_p): the rounding error of the low-precision add.
            torch._foreach_sub_(grads, params)
            torch._foreach_add_(kahan_comps, grads)
        else:
            torch._foreach_addcdiv_(params, exp_avgs, de_noms, step_sizes)

    def _step_per_param(self, group: ParamGroup) -> None:
        """Per-parameter update (used when the foreach path is unavailable or disabled)."""
        beta1, beta2 = group['betas']

        beta1_comp: float = 1.0 - self.debias_beta(beta1, group['step'])
        beta2_hat: float = self.debias_beta(beta2, group['step'])

        eps_p2: float = math.pow(group['eps'], 2)

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad

            # Previously missing: maximize was only honoured on the foreach path.
            self.maximize_gradient(grad, maximize=self.maximize)

            state = self.state[p]

            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

            p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

            if not group['weight_decouple']:
                # Coupled L2 penalty: fold it into the gradient before it reaches the moment estimates.
                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

            exp_avg.lerp_(grad, weight=beta1_comp)
            exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat)

            lr: float = group['lr'] / self.get_stable_adamw_rms(grad, exp_avg_sq, eps=eps_p2)

            if group['weight_decouple']:
                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

            if group['kahan_sum'] and p.dtype in (torch.float16, torch.bfloat16):
                kahan_comp = state['kahan_comp']
                kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

                # grad is reused as scratch holding the pre-update param.
                grad.copy_(p.detach())
                p.add_(kahan_comp)

                kahan_comp.add_(grad.sub_(p))
            else:
                p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if self._can_use_foreach(group):
                params, grads, state_dict = self.collect_trainable_params(
                    group, self.state, state_keys=['exp_avg', 'exp_avg_sq', 'kahan_comp']
                )
                if params:
                    self._step_foreach(
                        group,
                        params,
                        grads,
                        state_dict['exp_avg'],
                        state_dict['exp_avg_sq'],
                        state_dict['kahan_comp'],
                    )
            else:
                self._step_per_param(group)

        return loss

StableSPAM

Bases: BaseOptimizer

How to Train in 4-Bit More Stably than 16-Bit Adam.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of gradient and the squared Hessian trace.

(0.9, 0.999)
gamma1 float

Gamma1 parameter.

0.7
gamma2 float

Gamma2 parameter.

0.9
theta float

Theta parameter.

0.999
t_max Optional[int]

Total number of steps.

None
eta_min float

Eta_min of CosineDecay.

0.5
weight_decay float

Weight decay (L2 penalty).

0.0
update_proj_gap int

Update projection gap.

1000
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/spam.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
class StableSPAM(BaseOptimizer):
    r"""How to Train in 4-Bit More Stably than 16-Bit Adam.

    Each step clips gradient spikes against a running maximum, rescales the gradient to a running gradient-norm
    estimate, and then applies an Adam update whose moments are reset every ``update_proj_gap`` steps.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        gamma1 (float): Decay rate of the running average of the gradient norm. ``-1.0`` means use ``betas[0]``.
        gamma2 (float): Decay rate of the running average of the squared gradient norm.
        theta (float): Decay rate of the running average of the maximum absolute gradient (spike threshold).
        t_max (Optional[int]): Total number of steps. When set, ``beta1`` and ``gamma1`` are scaled by a cosine decay.
        eta_min (float): Eta_min of CosineDecay.
        weight_decay (float): Decoupled weight decay (as in AdamW).
        update_proj_gap (int): Interval, in steps, at which the first and second moment estimates are reset.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        gamma1: float = 0.7,
        gamma2: float = 0.9,
        theta: float = 0.999,
        t_max: Optional[int] = None,
        eta_min: float = 0.5,
        weight_decay: float = 0.0,
        update_proj_gap: int = 1000,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(update_proj_gap, 'update_proj_gap')
        self.validate_non_negative(eps, 'eps')

        self.gamma1: float = betas[0] if gamma1 == -1.0 else gamma1
        self.gamma2: float = gamma2
        self.theta: float = theta
        self.t_max = t_max
        self.update_proj_gap = update_proj_gap
        self.warmup = CosineDecay(1.0, t_max, eta_min=eta_min) if t_max is not None else None
        self.maximize = maximize

        self.total_step: int = 0

        defaults: Defaults = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'StableSPAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(grad)
                state['exp_avg_sq'] = torch.zeros_like(grad)
                state['m_norm_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)
                state['v_norm_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)
                state['m_max_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.total_step += 1

        # Cosine-decay factor applied to beta1 and gamma1; 1.0 when no schedule (t_max=None) is configured.
        scale: float = self.warmup.get_death_rate(self.total_step) if self.warmup is not None else 1.0

        # Decide once per step (before any bias correction is computed) whether the Adam moments restart.
        reset_moments: bool = self.update_proj_gap > 0 and self.total_step % self.update_proj_gap == 0

        gamma1: float = self.gamma1 * scale

        for group in self.param_groups:
            self.init_group(group)

            # `step` counts every update and debiases the spike / norm statistics, which are never reset.
            # `moment_step` restarts together with the Adam moments and debiases them. States saved before
            # `moment_step` existed fall back to `step`.
            group['step'] += 1
            group['moment_step'] = 1 if reset_moments else group.get('moment_step', group['step'] - 1) + 1

            beta1, beta2 = group['betas']
            beta1 *= scale

            bias_correction1: float = self.debias(beta1, group['moment_step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['moment_step']))

            step_size: float = group['lr'] / bias_correction1

            theta_t: float = 1.0 - self.theta ** group['step']
            norm_bias_correction1: float = 1.0 - gamma1 ** group['step']
            norm_bias_correction2: float = 1.0 - self.gamma2 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]
                exp_avg, exp_avg_sq, m_max_t = state['exp_avg'], state['exp_avg_sq'], state['m_max_t']

                if reset_moments:
                    # Zero in place so that this step's update already starts from the fresh moments.
                    exp_avg.zero_()
                    exp_avg_sq.zero_()

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                # Spike clipping: entries above the debiased running max of |grad| are scaled down.
                max_grad = torch.max(grad.abs())
                m_max_t.lerp_(max_grad, weight=1.0 - self.theta)
                m_max_hat = m_max_t / theta_t

                mask = grad.abs() > m_max_hat
                if mask.any():
                    # Boolean indexing returns a copy, so the clipped values must be written back explicitly.
                    grad[mask] = grad[mask].div(max_grad).mul(m_max_hat)

                grad_norm = torch.linalg.norm(grad)
                if grad_norm == 0:
                    continue

                m_norm_t, v_norm_t = state['m_norm_t'], state['v_norm_t']
                m_norm_t.lerp_(grad_norm, weight=1.0 - gamma1)
                v_norm_t.lerp_(grad_norm.pow(2), weight=1.0 - self.gamma2)

                m_norm_hat = m_norm_t / norm_bias_correction1
                v_norm_hat = v_norm_t / norm_bias_correction2

                # Rescale the gradient to the debiased running gradient-norm estimate.
                c_norm_t = m_norm_hat.div_(v_norm_hat.sqrt_().add_(group['eps']))

                grad.div_(grad_norm).mul_(c_norm_t)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss

SWATS

Bases: BaseOptimizer

Improving Generalization Performance by Switching from Adam to SGD.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
betas Betas

Coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

Whether to fix weight decay.

False
ams_bound bool

Whether to use the AMSBound variant of this algorithm from the paper.

False
nesterov bool

Enables Nesterov momentum.

False
eps float

Term added to the denominator to improve numerical stability.

1e-06
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/swats.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
class SWATS(BaseOptimizer):
    """Improving Generalization Performance by Switching from Adam to SGD.

    Each group starts in the ``'adam'`` phase. Once the running estimate of the projected SGD learning rate settles,
    the group switches to the ``'sgd'`` phase and uses that estimate as its learning rate.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Whether to fix weight decay.
        ams_bound (bool): Whether to use the AMSBound variant of this algorithm from the paper.
        nesterov (bool): Enables Nesterov momentum.
        eps (float): Term added to the denominator to improve numerical stability. Also used as the relative
            tolerance of the phase-switch test.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: Betas = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        nesterov: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'nesterov': nesterov,
            'phase': 'adam',
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SWATS'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg2'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'] * math.sqrt(bias_correction2),
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['phase'] == 'sgd':
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(grad)

                    buf = state['momentum_buffer']
                    buf.mul_(beta1).add_(grad)

                    # Dampened momentum (1 - beta1) * buf; Nesterov adds another beta1 * buf on top.
                    update = buf.mul(1.0 - beta1)
                    if group['nesterov']:
                        update.add_(buf, alpha=beta1)

                    p.add_(update, alpha=-group['lr'])

                    continue

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                perturb = exp_avg.div(de_nom).mul_(-step_size)

                p.add_(perturb)

                # reshape (not view): state tensors inherit the parameter's strides, so non-contiguous layouts
                # such as channels_last cannot be flattened with view().
                perturb_flat = perturb.reshape(-1)
                pg = perturb_flat.dot(grad.reshape(-1))

                if pg != 0:
                    # Projected SGD learning rate that would reproduce the Adam step along the gradient.
                    scaling = perturb_flat.dot(perturb_flat).div_(-pg)

                    exp_avg2 = state['exp_avg2']
                    exp_avg2.mul_(beta2).add_(scaling, alpha=1.0 - beta2)

                    corrected_exp_avg = exp_avg2 / bias_correction2

                    if (
                        group['step'] > 1
                        and corrected_exp_avg > 0.0
                        and corrected_exp_avg.allclose(scaling, rtol=group['eps'])
                    ):
                        group['phase'] = 'sgd'
                        group['lr'] = corrected_exp_avg.item()

        return loss

TAM

Bases: BaseOptimizer

Torque-Aware Momentum.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
momentum float

Coefficient used for computing running averages of gradient.

0.9
decay_rate float

Smoothing decay rate.

0.9
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Whether to fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

1e-08
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/tam.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class TAM(BaseOptimizer):
    """Torque-Aware Momentum.

    The gradient is damped by ``(1 + s) / 2``, where ``s`` is a running average of the alignment between the
    momentum buffer and the current gradient. The damped gradient is then accumulated into the momentum.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        momentum (float): Coefficient used for computing running averages of gradient.
        decay_rate (float): Decay rate of the running average of the momentum/gradient alignment.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Whether to fix weight decay.
        eps (float): Small constant added to the damping factor ``(1 + s) / 2``.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.9,
        decay_rate: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(decay_rate, 'decay_rate', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'momentum': momentum,
            'decay_rate': decay_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'TAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['s'] = torch.zeros_like(grad)
                state['momentum_buffer'] = grad.clone()

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            momentum: float = group['momentum']
            decay_rate: float = group['decay_rate']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                # Apply weight decay before `grad` is consumed. In the coupled mode (weight_decouple=False)
                # the penalty acts through `grad`, so applying it after the momentum update had no effect.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                state = self.state[p]

                s, momentum_buffer = state['s'], state['momentum_buffer']

                corr = normalize(momentum_buffer, p=2.0, dim=0).mul_(normalize(grad, p=2.0, dim=0))
                s.mul_(decay_rate).add_(corr, alpha=1.0 - decay_rate)

                d = ((1.0 + s) / 2.0).add_(group['eps']).mul_(grad)

                momentum_buffer.mul_(momentum).add_(d)

                p.add_(momentum_buffer, alpha=-group['lr'])

        return loss

Tiger

Bases: BaseOptimizer

A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.001
beta float

Coefficient used for computing the running average of the gradient.

0.965
weight_decay float

Weight decay (L2 penalty).

0.01
weight_decouple bool

Whether the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Whether to fix weight decay.

False
foreach Optional[bool]

Whether to use foreach (multi-tensor) operations for speed. None means auto-detect based on device (True for CUDA, False otherwise).

None
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/tiger.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class Tiger(BaseOptimizer):
    r"""A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious.

    Only one momentum buffer is kept per parameter, and each parameter moves by ``-lr * sign(momentum)``.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        beta (float): Coefficient used for computing the running average of the gradient.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Whether to fix weight decay.
        foreach (Optional[bool]): Whether to use foreach (multi-tensor) operations for speed.
            None means auto-detect based on device (True for CUDA, False otherwise).
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        beta: float = 0.965,
        weight_decay: float = 0.01,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize
        self.foreach = foreach

        super().__init__(
            params,
            {
                'lr': lr,
                'beta': beta,
                'weight_decay': weight_decay,
                'weight_decouple': weight_decouple,
                'fixed_decay': fixed_decay,
                'foreach': foreach,
            },
        )

    def __str__(self) -> str:
        return 'Tiger'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        group.setdefault('step', 0)

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if not param_state:
                param_state['exp_avg'] = torch.zeros_like(grad)

    def _can_use_foreach(self, group: ParamGroup) -> bool:
        # An explicit `foreach=False` always wins; anything else is delegated to the base-class check.
        requested = group.get('foreach')
        return requested is not False and self.can_use_foreach(group, requested)

    def _step_foreach(
        self,
        group: ParamGroup,
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
    ) -> None:
        lr: float = group['lr']

        if self.maximize:
            torch._foreach_neg_(grads)

        self.apply_weight_decay_foreach(
            params=params,
            grads=grads,
            lr=lr,
            weight_decay=group['weight_decay'],
            weight_decouple=group['weight_decouple'],
            fixed_decay=group['fixed_decay'],
        )

        # exp_avg <- beta * exp_avg + (1 - beta) * grad, then step along the sign of the momentum.
        torch._foreach_lerp_(exp_avgs, grads, weight=1.0 - group['beta'])
        torch._foreach_add_(params, torch._foreach_sign(exp_avgs), alpha=-lr)

    def _step_per_param(self, group: ParamGroup) -> None:
        beta: float = group['beta']
        lr: float = group['lr']

        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            self.maximize_gradient(grad, maximize=self.maximize)

            param_state = self.state[param]

            self.apply_weight_decay(
                p=param,
                grad=grad,
                lr=lr,
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=group['fixed_decay'],
            )

            exp_avg = param_state['exp_avg']
            exp_avg.mul_(beta).add_(grad, alpha=1.0 - beta)

            # torch.sign is not defined for complex tensors; torch.sgn (z / |z|) covers that case.
            direction = torch.sgn(exp_avg) if torch.is_complex(exp_avg) else torch.sign(exp_avg)
            param.add_(direction, alpha=-lr)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            if not self._can_use_foreach(group):
                self._step_per_param(group)
                continue

            params, grads, state_dict = self.collect_trainable_params(group, self.state, state_keys=['exp_avg'])
            if params:
                self._step_foreach(group, params, grads, state_dict['exp_avg'])

        return loss

TRAC

Bases: BaseOptimizer

A Parameter-Free Optimizer for Lifelong Reinforcement Learning.

Parameters:

Name Type Description Default
optimizer OptimizerInstanceOrClass

Base optimizer.

required
betas List[float]

List of beta values.

(0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999)
num_coefs int

Number of polynomial coefficients to use in the approximation.

128
s_prev float

Initial scale value.

1e-08
eps float

Term added to the denominator to improve numerical stability.

1e-08
Example

model = YourModel()
optimizer = TRAC(AdamW(model.parameters()))

for input, output in data:
    optimizer.zero_grad()
    loss = loss_fn(model(input), output)
    loss.backward()
    optimizer.step()

Source code in pytorch_optimizer/optimizer/trac.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
class TRAC(BaseOptimizer):
    """A Parameter-Free Optimizer for Lifelong Reinforcement Learning.

    TRAC wraps another optimizer. It exposes that optimizer's `param_groups`, `state`, `defaults` and `state_dict`
    as its own. Each `step()` runs the wrapped optimizer first. TRAC then rescales each parameter's offset from a
    fixed reference point (snapshotted on the first step) by a scalar `s`, which it derives from running
    statistics of the inner product between parameter deltas and gradients.

    Args:
        optimizer (OptimizerInstanceOrClass): Base optimizer. Extra `**kwargs` are forwarded to `load_optimizer`.
        betas (List[float]): List of beta values, one per running statistic (the `variance` and `sigma` entries).
        num_coefs (int): Number of polynomial coefficients to use in the approximation (passed to `ERF1994`).
        s_prev (float): Initial scale value, used to compute `f_term`.
        eps (float): Term added to the denominator to improve numerical stability.

    Example:
        model = YourModel()
        optimizer = TRAC(AdamW(model.parameters()))

        for input, output in data:
            optimizer.zero_grad()
            loss = loss_fn(model(input), output)
            loss.backward()
            optimizer.step()

    """

    def __init__(
        self,
        optimizer: OptimizerInstanceOrClass,
        betas: List[float] = (0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999),
        num_coefs: int = 128,
        s_prev: float = 1e-8,
        eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_positive(num_coefs, 'num_coefs')
        self.validate_non_negative(s_prev, 's_prev')
        self.validate_non_negative(eps, 'eps')

        # super().__init__() is never called here, so the step-hook dicts are created by hand.
        # NOTE(review): presumably needed by torch.optim.Optimizer's hook machinery; confirm.
        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self.betas = betas
        self.s_prev = s_prev
        self.eps = eps

        self.erf: nn.Module = ERF1994(num_coefs=num_coefs)
        # Normalising constant: s_prev / Im(erf(i / sqrt(2))), computed once on a CPU scalar tensor.
        self.f_term: torch.Tensor = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))

        self.defaults: Defaults = self.optimizer.defaults

    def __str__(self) -> str:
        return 'TRAC'

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    @property
    def state(self) -> State:
        """State of the wrapped optimizer. TRAC stores its own bookkeeping under the `'trac'` key."""
        return self.optimizer.state

    def state_dict(self) -> State:
        """Delegate to the wrapped optimizer's `state_dict()`."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: State) -> None:
        """Delegate to the wrapped optimizer's `load_state_dict()`."""
        self.optimizer.load_state_dict(state_dict)

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Snapshot the reference point `theta_ref` of every parameter in `group`.

        Expects `updates` in kwargs: a mapping from each parameter to its backed-up value. A missing parameter
        raises `KeyError`.
        """
        if 'step' not in group:
            group['step'] = 0

        updates: Dict[torch.Tensor, torch.Tensor] = kwargs.get('updates', {})

        for p in group['params']:
            self.state['trac'][p] = updates[p].clone()

    @torch.no_grad()
    def zero_grad(self, set_to_none: bool = True) -> None:
        """Delegate to the wrapped optimizer's `zero_grad()`."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    @torch.no_grad()
    def erf_imag(self, x: torch.Tensor) -> torch.Tensor:
        """Return the imaginary part of `self.erf(i * x)`.

        For an exact erf this is the imaginary error function erfi(x). Integer or complex inputs are first
        converted to their real part as float32.
        """
        if not torch.is_floating_point(x):
            x = x.real.to(torch.float32)

        ix = torch.complex(torch.zeros_like(x), x)

        return self.erf(ix).imag

    @torch.no_grad()
    def backup_params_and_grads(self) -> Tuple[Dict, Dict]:
        """Clone every parameter and its gradient (None when the parameter has no gradient)."""
        updates, grads = {}, {}

        for group in self.param_groups:
            for p in group['params']:
                updates[p] = p.clone()
                grads[p] = p.grad.clone() if p.grad is not None else None

        return updates, grads

    @torch.no_grad()
    def trac_step(self, updates: Dict, grads: Dict) -> None:
        """Rescale each parameter's offset from `theta_ref` by the TRAC scale `s`.

        Args:
            updates (Dict): parameter -> cloned value (from `backup_params_and_grads`). Mutated in place.
            grads (Dict): parameter -> cloned gradient, or None to skip that parameter.
        """
        self.state['trac']['step'] += 1

        deltas = {}

        device = self.param_groups[0]['params'][0].device

        s = self.state['trac']['s']
        # h accumulates sum_p <delta_p, grad_p> over all parameters that have a gradient.
        h = torch.zeros((1,), device=device)
        for group in self.param_groups:
            for p in group['params']:
                if grads[p] is None:
                    continue

                theta_ref = self.state['trac'][p]
                update = updates[p]

                # Offset from the reference point, normalised by the previous scale (s is not yet updated here).
                deltas[p] = (update - theta_ref) / s.add(self.eps)
                # NOTE(review): `update` is a clone of p taken after the inner step (see `step`), so this leaves it
                # all zeros and the later `delta.add_(update)` is a no-op. Related to the TODO in `step`.
                update.neg_().add_(p)

                grad, delta = grads[p], deltas[p]

                product = torch.dot(delta.flatten(), grad.flatten())
                h.add_(product)

                delta.add_(update)

                # Reset the parameter to its reference point; the scaled delta is added back below.
                p.copy_(theta_ref)

        betas = self.state['trac']['betas']
        variance = self.state['trac']['variance']
        sigma = self.state['trac']['sigma']

        # Per-beta running statistics: variance <- beta^2 * variance + h^2, sigma <- beta * sigma - h.
        variance.mul_(betas.pow(2)).add_(h.pow(2))
        sigma.mul_(betas).sub_(h)

        # s = sum over betas of f_term * erf_imag(sigma / (sqrt(2 * variance) + eps)).
        term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps)).mul_(self.f_term)
        s.copy_(torch.sum(term))

        # Negative scales are clamped to zero, which leaves the parameter at theta_ref.
        scale = max(s, 0.0)

        for group in self.param_groups:
            for p in group['params']:
                if grads[p] is None:
                    continue

                p.add_(deltas[p] * scale)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Run the wrapped optimizer's step, then apply the TRAC rescaling. Returns the wrapped step's loss."""
        # TODO(kozistr): backup is first to get the delta of param and grad, but it does not work.
        with torch.enable_grad():
            loss = self.optimizer.step(closure)

        updates, grads = self.backup_params_and_grads()

        if 'trac' not in self.state:
            # First call: create TRAC's bookkeeping and snapshot theta_ref from the just-updated parameters.
            device = self.param_groups[0]['params'][0].device

            self.state['trac'] = {
                'betas': torch.tensor(self.betas, device=device),
                's': torch.zeros(1, device=device),
                'variance': torch.zeros(len(self.betas), device=device),
                'sigma': torch.full((len(self.betas),), 1e-8, device=device),
                'step': 0,
            }

            for group in self.param_groups:
                self.init_group(group, updates=updates)

        self.trac_step(updates, grads)

        return loss

VSGD

Bases: BaseOptimizer

Variational Stochastic Gradient Descent for Deep Neural Networks.

Parameters:

Name Type Description Default
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

learning rate.

0.1
ghattg float

prior variance ratio between ghat and g, Var(ghat_t-g_t)/Var(g_t-g_{t-1}).

30.0
ps float

prior strength.

1e-08
tau1 float

remember rate for the gamma parameters of g.

0.81
tau2 float

remember rate for the gamma parameter of ghat.

0.9
weight_decay float

weight decay (L2 penalty).

0.0
weight_decouple bool

optimizer uses decoupled weight decay as in AdamW.

True
eps float

term added to denominator to improve numerical stability.

1e-08
maximize bool

maximize the objective instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
class VSGD(BaseOptimizer):
    """Variational Stochastic Gradient Descent for Deep Neural Networks.

    Each parameter keeps a fused gradient mean ``mug`` plus two running statistics, ``bg`` (for the "true"
    gradient g) and ``bhg`` (for the observed gradient ghat), from which the variances ``sg`` / ``shg`` are derived.
    Every step fuses the previous mean and the observed gradient weighted by those variances, then moves the
    parameter by ``lr * mug / (sqrt(mug**2 + sigg) + eps)``.

    Args:
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): learning rate.
        ghattg (float): prior variance ratio between ghat and g, Var(ghat_t-g_t)/Var(g_t-g_{t-1}).
        ps (float): prior strength. Must be strictly positive.
        tau1 (float): remember rate for the gamma parameters of g; the running-average weight is ``step ** -tau1``.
        tau2 (float): remember rate for the gamma parameter of ghat; the running-average weight is ``step ** -tau2``.
        weight_decay (float): weight decay (L2 penalty).
        weight_decouple (bool): optimizer uses decoupled weight decay as in AdamW.
        eps (float): term added to denominator to improve numerical stability.
        maximize (bool): maximize the objective instead of minimizing.

    Raises:
        ValueError: if ``ps`` is zero. With ``ps == 0`` the first step derives both variances as 0 and computes the
            fused mean as 0 / 0, turning every parameter into NaN.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-1,
        ghattg: float = 30.0,
        ps: float = 1e-8,
        tau1: float = 0.81,
        tau2: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(ghattg, 'ghattg')
        self.validate_non_negative(ps, 'ps')
        # ps only feeds the prior terms pbg2 / pbhg2; at zero the step-1 variances vanish and the update is 0 / 0.
        if ps <= 0.0:
            raise ValueError(f'ps must be positive, got {ps}')
        self.validate_non_negative(tau1, 'tau1')
        self.validate_non_negative(tau2, 'tau2')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'tau1': tau1,
            'tau2': tau2,
            # prior hyper-parameters derived once from ps / ghattg
            'pa2': 2.0 * ps + 1.0 + 1e-4,
            'pbg2': 2.0 * ps,
            'pbhg2': 2.0 * ghattg * ps,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'VSGD'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the group step counter and zero-filled per-parameter state for parameters with gradients.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['mug'] = torch.zeros_like(p)
                state['bg'] = torch.zeros_like(p)
                state['bhg'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            pa2, pbg2, pbhg2 = group['pa2'], group['pbg2'], group['pbhg2']

            # decaying running-average weights; both equal 1.0 on the first step
            rho1: float = math.pow(group['step'], -group['tau1'])
            rho2: float = math.pow(group['step'], -group['tau2'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                bg, bhg = state['bg'], state['bhg']

                # variances of g and ghat: taken from the prior on the first step, from the running stats afterwards
                if group['step'] == 1:
                    sg = pbg2 / (pa2 - 1.0)
                    shg = pbhg2 / (pa2 - 1.0)
                else:
                    sg = bg / pa2
                    shg = bhg / pa2

                mug = state['mug']
                mug_prev = mug.clone()

                # precision-weighted fusion: mug = (mug_prev * shg + grad * sg) / (sg + shg)
                mug.mul_(shg).add_(grad * sg).div_(sg + shg)

                # fused variance, and the second moment mug^2 + sigg used below and as the step denominator
                sigg = (sg * shg) / (sg + shg)
                mug_sq = mug.pow(2).add_(sigg)

                bg2 = pbg2 + mug_sq - 2.0 * mug * mug_prev + mug_prev.pow(2)
                bhg2 = pbhg2 + mug_sq - 2.0 * grad * mug + grad.pow(2)

                bg.mul_(1.0 - rho1).add_(bg2, alpha=rho1)
                bhg.mul_(1.0 - rho2).add_(bhg2, alpha=rho2)

                p.add_(group['lr'] / mug_sq.sqrt().add_(group['eps']) * mug, alpha=-1.0)

        return loss

WSAM

Bases: BaseOptimizer

Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term.

Parameters:

Name Type Description Default
model Union[Module, DistributedDataParallel]

the model instance. A DDP model is recommended so that model.no_sync works.

required
params ParamsT

iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer Optimizer

base optimizer.

required
rho float

size of the neighborhood for computing the max loss.

0.05
gamma float

weighting factor gamma / (1 - gamma) of the sharpness term. Values between 0.8 and 0.95 are recommended.

0.9
adaptive bool

element-wise adaptive SAM.

False
decouple bool

whether to perform a decoupled sharpness regularization.

True
max_norm Optional[float]

max norm of the gradients.

None
eps float

term added to the denominator of WSAM to improve numerical stability.

1e-12
kwargs Dict

parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
class WSAM(BaseOptimizer):
    """Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term.

    ``step`` evaluates the closure twice: once at the current weights ``w`` (plain gradient) and once at the
    perturbed point ``w + e(w)`` (SAM gradient). The two are combined with weight ``alpha = gamma / (1 - gamma)``.
    With ``decouple=False`` the combination goes into the gradient itself. With ``decouple=True`` it is applied as a
    separate sharpness step after the base optimizer.

    Args:
        model (Union[torch.nn.Module, DistributedDataParallel]): the model instance. A DDP model is recommended so
            that `model.no_sync` works.
        params (ParamsT): iterable of parameters to optimize or dicts defining parameter groups.
        base_optimizer (Optimizer): base optimizer.
        rho (float): size of the neighborhood for computing the max loss.
        gamma (float): weighting factor gamma / (1 - gamma) of the sharpness term. Must lie in [0, 1); values in
            0.8 ~ 0.95 are recommended.
        adaptive (bool): element-wise adaptive SAM.
        decouple (bool): whether to perform a decoupled sharpness regularization.
        max_norm (Optional[float]): max norm of the gradients.
        eps (float): term added to the denominator of WSAM to improve numerical stability.
        kwargs (Dict): parameters for optimizer.

    Raises:
        ValueError: if ``gamma`` is outside [0, 1). ``gamma == 1`` would divide by zero, and any other value outside
            the interval yields a negative sharpness weight ``alpha``.
    """

    def __init__(
        self,
        model: Union[nn.Module, DistributedDataParallel],
        params: ParamsT,
        base_optimizer: OptimizerType,
        rho: float = 0.05,
        gamma: float = 0.9,
        adaptive: bool = False,
        decouple: bool = True,
        max_norm: Optional[float] = None,
        eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f'gamma must be in [0, 1), got {gamma}')

        self.model = model
        self.decouple = decouple
        self.max_norm = max_norm

        alpha: float = gamma / (1.0 - gamma)

        defaults: Defaults = {'rho': rho, 'alpha': alpha, 'adaptive': adaptive, 'sam_eps': eps, **kwargs}

        super().__init__(params, defaults)

        # the base optimizer works on the same param_groups, so group-level settings stay shared
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'WSAM'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        pass

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        """Move every parameter to ``w + e(w)`` and snapshot the plain gradient in ``state['grad']``.

        Args:
            zero_grad (bool): zero the gradients after the snapshot is taken.
        """
        device = self.param_groups[0]['params'][0].device

        # NOTE(review): `get_global_gradient_norm` is defined elsewhere; confirm it returns the L2 norm (not its
        # square) and whether it should reflect the `adaptive` scaling applied to `e_w` below.
        grad_norm = get_global_gradient_norm(self.param_groups, device)

        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + group['sam_eps'])

            for p in group['params']:
                if p.grad is None:
                    continue

                # adaptive SAM scales the perturbation element-wise by p ** 2
                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p)

                # climb to "w + e(w)"; e_w is kept so `second_step` can undo it
                p.add_(e_w)

                self.state[p]['e_w'] = e_w

                if is_initialized():  # pragma: no cover
                    all_reduce(p.grad, op=ReduceOp.AVG)

        # keep the plain gradient at w before it is zeroed / overwritten by the second closure call
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['grad'] = p.grad.clone()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """Restore ``w``, combine the plain and SAM gradients, and run the base optimizer.

        Args:
            zero_grad (bool): zero the gradients after the update.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if is_initialized():  # pragma: no cover
                    all_reduce(p.grad, ReduceOp.AVG)

                # get back to "w" from "w + e(w)"
                p.add_(self.state[p]['e_w'], alpha=-1.0)

        if self.max_norm is not None:
            clip_grad_norm_(self.model.parameters(), self.max_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if not self.decouple:
                    # g = alpha * g_sam + (1 - alpha) * g_plain
                    p.grad.mul_(group['alpha']).add_(self.state[p]['grad'], alpha=1.0 - group['alpha'])
                else:
                    # set the sharpness (g_sam - g_plain) aside; the base optimizer only sees g_plain
                    self.state[p]['sharpness'] = p.grad.clone() - self.state[p]['grad']
                    p.grad.mul_(0.0).add_(self.state[p]['grad'], alpha=1.0)

        self.base_optimizer.step()

        if self.decouple:
            # decoupled sharpness step scaled by lr * alpha (requires an 'lr' entry in each group)
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    p.add_(self.state[p]['sharpness'], alpha=-group['lr'] * group['alpha'])

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: Closure = None):
        """Run a full WSAM update.

        Args:
            closure (Closure): required callable that re-evaluates the model and populates ``p.grad``; it is called
                twice, at ``w`` and at ``w + e(w)``.

        Returns:
            The loss returned by the first closure call (at the unperturbed weights).

        Raises:
            NoClosureError: if ``closure`` is None.
        """
        if closure is None:
            raise NoClosureError(str(self))

        closure = torch.enable_grad()(closure)

        # NOTE(review): the running-stats helpers are defined elsewhere; presumably they stop BatchNorm statistics
        # from being updated during the perturbed pass — confirm.
        enable_running_stats(self.model)
        loss = closure()  # pyright: ignore[reportOptionalCall]

        self.first_step(zero_grad=True)

        disable_running_stats(self.model)
        closure()  # pyright: ignore[reportOptionalCall]

        self.second_step()

        return loss

    def load_state_dict(self, state_dict: Dict):
        """Load optimizer state and re-share ``param_groups`` with the base optimizer."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

Yogi

Bases: BaseOptimizer

Yogi: Adaptive Methods for Nonconvex Optimization.

Parameters:

Name Type Description Default
params ParamsT

Iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

Learning rate.

0.01
betas Betas

Coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
initial_accumulator float

Initial values for first and second moments.

1e-06
weight_decay float

Weight decay (L2 penalty).

0.0
weight_decouple bool

Whether the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

Whether to fix weight decay.

False
eps float

Term added to the denominator to improve numerical stability.

0.001
maximize bool

Maximize the objective with respect to the parameters instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/yogi.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class Yogi(BaseOptimizer):
    r"""Yogi: Adaptive Methods for Nonconvex Optimization.

    Adam-like optimizer whose second-moment estimate is updated additively,
    ``v_t = v_{t-1} - (1 - beta2) * sign(v_{t-1} - g_t^2) * g_t^2``, instead of by an exponential moving average.
    Both moment buffers start at ``initial_accumulator`` rather than zero.

    Args:
        params (ParamsT): Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float): Learning rate.
        betas (Betas): Coefficients used for computing running averages of the gradient and its square.
        initial_accumulator (float): Initial values for first and second moments.
        weight_decay (float): Weight decay (L2 penalty).
        weight_decouple (bool): Whether the optimizer uses decoupled weight decay as in AdamW.
        fixed_decay (bool): Whether to fix weight decay.
        eps (float): Term added to the denominator to improve numerical stability.
        maximize (bool): Maximize the objective with respect to the parameters instead of minimizing.

    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-2,
        betas: Betas = (0.9, 0.999),
        initial_accumulator: float = 1e-6,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-3,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: Defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'initial_accumulator': initial_accumulator,
            'eps': eps,
            # extra options (e.g. 'adam_debias', read in `step`) are stored per group
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Yogi'

    def init_group(self, group: ParamGroup, **kwargs) -> None:
        """Initialize the group step counter and both moment buffers to ``initial_accumulator``.

        Raises:
            NoSparseGradientError: if a parameter has a sparse gradient.
        """
        if 'step' not in group:
            group['step'] = 0

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.full_like(grad, fill_value=group['initial_accumulator'])
                state['exp_avg_sq'] = torch.full_like(grad, fill_value=group['initial_accumulator'])

    @torch.no_grad()
    def step(self, closure: Closure = None) -> Loss:
        """Perform a single optimization step.

        Args:
            closure (Closure): optional callable that re-evaluates the model and returns the loss.

        Returns:
            Loss: the value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss: Loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False), step_size=group['lr'], bias_correction1=bias_correction1
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # squared magnitude |g|^2. For complex gradients g * g would give g^2 (e.g. i * i = -1), so multiply
                # by the conjugate instead; for real tensors `conj()` returns the input, so this equals grad * grad.
                grad_p2 = grad.mul(grad.conj())

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                # Yogi update: v -= (1 - beta2) * sign(v - g^2) * g^2
                exp_avg_sq.addcmul_(
                    (
                        (exp_avg_sq - grad_p2).sign_()
                        if not torch.is_complex(exp_avg_sq)
                        else (exp_avg_sq - grad_p2).sgn_()
                    ),
                    grad_p2,
                    value=-(1.0 - beta2),
                )

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss