Skip to content

Optimizers

create_optimizer(model, optimizer_name, lr=0.001, weight_decay=0.0, wd_ban_list=('bias', 'LayerNorm.bias', 'LayerNorm.weight'), use_lookahead=False, use_orthograd=False, **kwargs)

Build optimizer.

Parameters:

Name Type Description Default
model Module

nn.Module. model.

required
optimizer_name str

str. name of optimizer.

required
lr float

float. learning rate.

0.001
weight_decay float

float. weight decay.

0.0
wd_ban_list List[str]

List[str]. weight decay ban list by layer.

('bias', 'LayerNorm.bias', 'LayerNorm.weight')
use_lookahead bool

bool. use Lookahead.

False
use_orthograd bool

bool. use OrthoGrad.

False
Source code in pytorch_optimizer/optimizer/__init__.py
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def create_optimizer(
    model: nn.Module,
    optimizer_name: str,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
    use_lookahead: bool = False,
    use_orthograd: bool = False,
    **kwargs,
) -> Optimizer:
    r"""Instantiate an optimizer by name, optionally wrapped with OrthoGrad and/or Lookahead.

    :param model: nn.Module. model whose parameters are optimized.
    :param optimizer_name: str. name of optimizer (matched case-insensitively).
    :param lr: float. learning rate. forwarded as `max_lr` for the 'alig' optimizer.
    :param weight_decay: float. weight decay. parameter groups are only built when it is positive.
    :param wd_ban_list: List[str]. weight decay ban list by layer.
    :param use_lookahead: bool. wrap the optimizer with Lookahead.
    :param use_orthograd: bool. wrap the optimizer with OrthoGrad.
    :param kwargs: extra keyword arguments forwarded to the optimizer (and to OrthoGrad). `k`, `alpha` and
        `pullback_momentum` are additionally read to configure Lookahead.
    :returns: Optimizer. the (possibly wrapped) optimizer.
    """
    optimizer_name = optimizer_name.lower()

    # Split parameters into decay / no-decay groups only when decay is actually requested.
    if weight_decay > 0.0:
        parameters = get_optimizer_parameters(model, weight_decay, wd_ban_list)
    else:
        parameters = model.parameters()

    optimizer_class: OPTIMIZER = load_optimizer(optimizer_name)

    if optimizer_name in {'lomo', 'adalomo', 'adammini'}:
        # These optimizers are constructed from the model itself rather than from its parameters.
        optimizer = optimizer_class(model, lr=lr, **kwargs)
    elif optimizer_name == 'alig':
        optimizer = optimizer_class(parameters, max_lr=lr, **kwargs)
    else:
        optimizer = optimizer_class(parameters, lr=lr, **kwargs)

    if use_orthograd:
        optimizer = OrthoGrad(optimizer, **kwargs)

    if not use_lookahead:
        return optimizer

    if optimizer_name in ('ranger', 'ranger21', 'ranger25'):
        # Warn and hand back the optimizer without adding another Lookahead wrapper.
        warn(f'{optimizer} already has a Lookahead variant.', UserWarning, 1)
        return optimizer

    lookahead_k = kwargs.get('k', 5)
    lookahead_alpha = kwargs.get('alpha', 0.5)
    lookahead_pullback = kwargs.get('pullback_momentum', 'none')

    return Lookahead(
        optimizer,
        k=lookahead_k,
        alpha=lookahead_alpha,
        pullback_momentum=lookahead_pullback,
    )

get_optimizer_parameters(model_or_parameter, weight_decay, wd_ban_list=('bias', 'LayerNorm.bias', 'LayerNorm.weight'))

Get optimizer parameters while filtering specified modules.

Note that you can also ban at the module-name level (e.g. LayerNorm) if you pass an nn.Module instance: passing LayerNorm is enough to exclude the layer norm layer(s) from weight decay.

Parameters:

Name Type Description Default
model_or_parameter Union[Module, List]

Union[nn.Module, List]. model or parameters.

required
weight_decay float

float. weight_decay.

required
wd_ban_list List[str]

List[str]. ban list not to set weight decay.

('bias', 'LayerNorm.bias', 'LayerNorm.weight')

Returns:

Type Description
PARAMETERS

PARAMETERS. new parameter list.

Source code in pytorch_optimizer/optimizer/__init__.py
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
def get_optimizer_parameters(
    model_or_parameter: Union[nn.Module, List],
    weight_decay: float,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
) -> PARAMETERS:
    r"""Get optimizer parameters while filtering specified modules.

    Note that you can also ban at the module-name level (e.g. `LayerNorm`) if you pass an nn.Module instance:
    passing `LayerNorm` is enough to exclude the layer norm layer(s) from weight decay.

    When a module is given, every parameter whose full name, or whose owning module's class name, contains an
    entry of `wd_ban_list` is banned, and the ban is applied by exact parameter name. When an iterable of
    `(name, parameter)` pairs is given, a parameter is banned if its name contains any entry of `wd_ban_list`.
    Parameters with `requires_grad=False` are left out of both groups.

    :param model_or_parameter: Union[nn.Module, List]. model or (name, parameter) pairs. any iterable of pairs is
        accepted, including a generator such as `model.named_parameters()`.
    :param weight_decay: float. weight_decay.
    :param wd_ban_list: List[str]. ban list not to set weight decay.
    :returns: PARAMETERS. new parameter list. the first group uses `weight_decay`, the second group uses 0.0.
    """
    if isinstance(model_or_parameter, nn.Module):
        banned_names: Set[str] = set()

        for module_name, module in model_or_parameter.named_modules():
            module_type: str = module._get_name()
            for param_name, _ in module.named_parameters(recurse=False):
                full_param_name: str = f'{module_name}.{param_name}' if module_name else param_name
                if any(banned in full_param_name or banned in module_type for banned in wd_ban_list):
                    banned_names.add(full_param_name)

        named_parameters = list(model_or_parameter.named_parameters())

        def is_banned(name: str) -> bool:
            # The ban set holds full parameter names, so match exactly. A substring test would also ban
            # unrelated parameters such as 'prenorm.weight' when only 'norm.weight' is banned.
            return name in banned_names

    else:
        # Materialise once so that one-shot iterables (generators) are not exhausted before partitioning.
        named_parameters = list(model_or_parameter)
        banned_patterns: Set[str] = set(wd_ban_list)

        def is_banned(name: str) -> bool:
            return any(pattern in name for pattern in banned_patterns)

    decay_params: List = []
    no_decay_params: List = []
    for name, param in named_parameters:
        if not param.requires_grad:
            continue

        if is_banned(name):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

A2Grad

Bases: BaseOptimizer

Optimal Adaptive and Accelerated Stochastic Gradient Descent.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr Optional[float]

Optional[float]. learning rate. not needed (this optimizer does not use it).

None
beta float

float. beta.

10.0
lips float

float. Lipschitz constant.

10.0
rho float

float. represents the degree of weighting decrease, a constant smoothing factor between 0 and 1.

0.5
variant VARIANTS

str. variant of A2Grad optimizer. 'uni', 'inc', 'exp'.

'uni'
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/a2grad.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class A2Grad(BaseOptimizer):
    r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: Optional[float]. learning rate. it is validated but not otherwise used by this optimizer.
    :param beta: float. beta. scales the adaptive term `h_k` in the step-size denominator.
    :param lips: float. Lipschitz constant.
    :param rho: float. represents the degree of weighting decrease, a constant smoothing factor between 0 and 1.
        only stored when `variant` is 'exp'.
    :param variant: str. variant of A2Grad optimizer. 'uni', 'inc', 'exp'.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: Optional[float] = None,
        beta: float = 10.0,
        lips: float = 10.0,
        rho: float = 0.5,
        variant: VARIANTS = 'uni',
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(lips, 'lips')
        self.validate_non_negative(rho, 'rho')
        self.validate_options(variant, 'variant', ['uni', 'inc', 'exp'])

        self.variant = variant
        self.maximize = maximize

        defaults: DEFAULTS = {'beta': beta, 'lips': lips}
        if variant == 'exp':
            # Only the exponential variant reads `rho` in `step`.
            defaults['rho'] = rho

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'A2Grad'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['alpha_k'] = 1.0
                state['v_k'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                # The running gradient mean is seeded with the first observed gradient.
                state['avg_grad'] = grad.clone()
                state['x_k'] = p.clone()
                if self.variant == 'exp':
                    state['v_kk'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :returns: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            step: int = group['step']

            gamma_k: float = 2.0 * group['lips'] / (step + 1)
            alpha_k_1: float = 2.0 / (step + 3)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # NOTE(review): helper lives on BaseOptimizer; presumably flips the gradient sign in place
                # when `maximize` is set - confirm against the base class.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                avg_grad, v_k, x_k = state['avg_grad'], state['v_k'], state['x_k']

                # Running arithmetic mean of the gradients: avg += (grad - avg) / step. `step` is 1-based and
                # `avg_grad` is seeded with the first gradient, so the first step leaves it unchanged.
                # The correction must be divided by the step count; scaling it up makes the "average"
                # overshoot further on every step.
                avg_grad.add_(grad - avg_grad, alpha=1.0 / step)

                # Squared norm of the deviation of the current gradient from the running mean.
                delta_k = grad.clone()
                delta_k.add_(avg_grad, alpha=-1.0)

                delta_k_sq = delta_k.pow(2).sum()

                if self.variant in ('uni', 'inc'):
                    if self.variant == 'inc':
                        v_k.mul_((step / (step + 1)) ** 2)
                    v_k.add_(delta_k_sq)
                else:
                    v_kk = state['v_kk']

                    # 'exp': exponential moving average of the deviation, with v_k tracking its running max.
                    v_kk.mul_(group['rho']).add_(delta_k_sq, alpha=1.0 - group['rho'])
                    torch.max(v_kk, v_k, out=v_k)

                h_k = v_k.sqrt()
                if self.variant != 'uni':
                    h_k.mul_(math.sqrt(step + 1))

                coefficient = -1.0 / (gamma_k + group['beta'] * h_k.item())

                x_k.add_(grad, alpha=coefficient)

                p.mul_(1.0 - alpha_k_1).add_(x_k, alpha=alpha_k_1)
                p.add_(grad, alpha=(1.0 - alpha_k_1) * state['alpha_k'] * coefficient)

                state['alpha_k'] = alpha_k_1

        return loss

AdaBelief

Bases: BaseOptimizer

Adapting Step-sizes by the Belief in Observed Gradients.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
rectify bool

bool. perform the rectified update similar to RAdam.

False
n_sma_threshold int

int. SMA threshold (the recommended value is 5).

5
degenerated_to_sgd bool

bool. perform SGD update when variance of gradient is high.

True
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-16
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adabelief.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class AdaBelief(BaseOptimizer):
    r"""Adapting Step-sizes by the Belief in Observed Gradients.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. decay rates of the running average of the gradient (`exp_avg`) and of the running
        average of the squared gradient residual (`exp_avg_var`).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param n_sma_threshold: int. SMA threshold (the recommended value is 5).
    :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: extra options merged into the group defaults. `adam_debias`, `adanorm` and `adanorm_r` are
        read from the group during initialization and stepping.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        eps: float = 1e-16,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'eps': eps,
        }
        # Extra options ride along in every parameter group.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaBelief'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state buffers for each parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if len(state) != 0:
                # Already initialized; keep the existing buffers.
                continue

            state['exp_avg'] = torch.zeros_like(param)
            state['exp_avg_var'] = torch.zeros_like(param)

            if group['ams_bound']:
                state['max_exp_avg_var'] = torch.zeros_like(param)

            if group.get('adanorm'):
                state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :returns: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            current_step = group['step']
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, current_step)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, current_step))

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=current_step,
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg = state['exp_avg']
                exp_avg_var = state['exp_avg_var']

                p, grad, exp_avg, exp_avg_var = self.view_as_real(p, grad, exp_avg, exp_avg_var)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                # First moment: running average of the (possibly AdaNorm-scaled) gradient.
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

                # Second moment: running average of the squared deviation of the gradient from `exp_avg`.
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_var,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )

                if not group['rectify']:
                    # Plain update: bias-correct the denominator, then take the adaptive step.
                    de_nom.div_(bias_correction2_sq)
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif n_sma >= self.n_sma_threshold:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    # Below the SMA threshold: un-normalized step along `exp_avg`.
                    p.add_(exp_avg, alpha=-step_size)

        return loss

AdaBound

Bases: BaseOptimizer

Adaptive Gradient Methods with Dynamic Bound of Learning Rate.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
final_lr float

float. final learning rate.

0.1
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
gamma float

float. convergence speed of the bound functions.

0.001
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adabound.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class AdaBound(BaseOptimizer):
    r"""Adaptive Gradient Methods with Dynamic Bound of Learning Rate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param final_lr: float. final learning rate.
    :param betas: BETAS. decay rates of the running averages of the gradient and of the squared gradient.
    :param gamma: float. convergence speed of the bound functions.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: extra options merged into the group defaults. `adam_debias` is read from the group in `step`.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        final_lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        gamma: float = 1e-3,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'final_lr': final_lr,
            'gamma': gamma,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
            # Forward extra options (e.g. `adam_debias`) into the groups; `step` reads them via `group.get`,
            # so dropping them here would make those constructor arguments silently ineffective.
            **kwargs,
        }

        super().__init__(params, defaults)

        # Learning rate of each group at construction time; `step` rescales `final_lr` by lr / base_lr.
        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'AdaBound'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state buffers for each parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :returns: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # Scale the final learning rate by the same factor the current lr has moved from its initial value,
            # then derive the clamp range; both bounds approach `final_lr` as the step count grows.
            final_lr: float = group['final_lr'] * group['lr'] / base_lr
            lower_bound: float = final_lr * (1 - 1 / (group['gamma'] * group['step'] + 1))
            upper_bound: float = final_lr * (1 + 1 / (group['gamma'] * group['step']))

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'] * bias_correction2_sq,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                # Element-wise step size, clamped into [lower_bound, upper_bound], applied to the first moment.
                update = torch.full_like(de_nom, fill_value=step_size)
                update.div_(de_nom).clamp_(min=lower_bound, max=upper_bound).mul_(exp_avg)

                # In-place subtraction avoids allocating a negated copy of `update`.
                p.sub_(update)

        return loss

AdaDelta

Bases: BaseOptimizer

An Adaptive Learning Rate Method.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
rho float

float. coefficient used for computing a running average of squared gradients.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adadelta.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class AdaDelta(BaseOptimizer):
    r"""An Adaptive Learning Rate Method.

    Two running averages are kept per parameter: one of the squared gradients and one of the squared deltas.
    The delta applied to a parameter is the gradient scaled by the ratio of the square roots of those averages.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate; the computed delta is multiplied by it before being subtracted.
    :param rho: float. decay coefficient shared by both running averages.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to both running averages before the square root, for numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        rho: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        # Reject invalid hyper-parameters before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_range(rho, 'rho', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # Per-group defaults; extra keyword arguments are forwarded into the group options.
        defaults: DEFAULTS = dict(
            lr=lr,
            rho=rho,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
        )
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaDelta'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the zero-filled running averages for every parameter of `group` that has a gradient.

        :param group: GROUP. parameter group to initialise.
        :raises NoSparseGradientError: if a parameter carries a sparse gradient.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                # state already exists for this parameter; leave it untouched.
                continue

            param_state['square_avg'] = torch.zeros_like(param)
            param_state['acc_delta'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # The first visit of a group initialises its state and starts the step counter at 1.
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            rho: float = group['rho']
            one_minus_rho: float = 1.0 - rho

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]

                self.apply_weight_decay(
                    p=param,
                    grad=gradient,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                square_avg = param_state['square_avg']
                acc_delta = param_state['acc_delta']

                param, gradient, square_avg, acc_delta = self.view_as_real(param, gradient, square_avg, acc_delta)

                # running average of squared gradients (updated in place).
                square_avg.mul_(rho).addcmul_(gradient, gradient, value=one_minus_rho)

                # delta = sqrt(acc_delta + eps) / sqrt(square_avg + eps) * grad, built on fresh tensors so the
                # stored averages are not modified here.
                denominator = torch.add(square_avg, group['eps']).sqrt_()
                delta = torch.add(acc_delta, group['eps'])
                delta.sqrt_().div_(denominator).mul_(gradient)

                # running average of squared deltas (updated in place).
                acc_delta.mul_(rho).addcmul_(delta, delta, value=one_minus_rho)

                param.add_(delta, alpha=-group['lr'])

        return loss

AdaFactor

Bases: BaseOptimizer

Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr Optional[float]

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. if beta1 is None, first momentum will be skipped.

(0.9, 0.999)
decay_rate float

float. coefficient used to compute running averages of square gradient.

-0.8
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
clip_threshold float

float. threshold of root-mean-square of final gradient update.

1.0
ams_bound bool

bool. whether to use the AMSBound variant.

False
scale_parameter bool

bool. if true, learning rate is scaled by root-mean-square of parameter.

True
relative_step bool

bool. if true, time-dependent learning rate is computed instead of external learning rate.

True
warmup_init bool

bool. time-dependent learning rate computation depends on whether warm-up initialization is being used.

False
eps1 float

float. term added to the denominator to improve numerical stability.

1e-30
eps2 float

float. term added to the denominator to improve numerical stability.

0.001
momentum_dtype dtype

torch.dtype. type of momentum variable. The ViT paper observed that storing momentum in half-precision (bfloat16) does not affect training dynamics or the outcome, while reducing optimizer overhead from 2-fold to 1.5-fold.

bfloat16
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adafactor.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class AdaFactor(BaseOptimizer):
    r"""Adaptive Learning Rates with Sublinear Memory Cost with some tweaks.

    Gradients with two or more dimensions keep their second-moment estimate in factored form (one running average
    of the mean over the last axis and one of the mean over the second-to-last axis). All other gradients keep a
    full running average.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate. only used as the step size when `relative_step` is False.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
        if beta1 is None, first momentum will be skipped.
    :param decay_rate: float. coefficient used to compute running averages of square gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param clip_threshold: float. threshold of root-mean-square of final gradient update.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param scale_parameter: bool. if true, learning rate is scaled by root-mean-square of parameter.
    :param relative_step: bool. if true, time-dependent learning rate is computed instead of external learning rate.
    :param warmup_init: bool. time-dependent learning rate computation depends on whether warm-up initialization is
        being used.
    :param eps1: float. term added to the squared gradient to improve numerical stability.
    :param eps2: float. lower bound of the parameter root-mean-square used when `scale_parameter` is True.
    :param momentum_dtype: torch.dtype. type of momentum variable. In VIT paper observed that storing momentum in
        half-precision (bfloat16 type) does not affect training dynamics and has no effect on the outcome while
        reducing optimize overhead from 2-fold to 1.5-fold.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: Optional[float] = 1e-3,
        betas: BETAS = (0.9, 0.999),
        decay_rate: float = -0.8,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        clip_threshold: float = 1.0,
        ams_bound: bool = False,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
        eps1: float = 1e-30,
        eps2: float = 1e-3,
        momentum_dtype: torch.dtype = torch.bfloat16,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        # These options are optimizer-wide: `step` and `get_lr` read them from `self`, not from the groups.
        self.decay_rate = decay_rate
        self.clip_threshold = clip_threshold
        self.eps1 = eps1
        self.eps2 = eps2
        self.momentum_dtype = momentum_dtype
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'scale_parameter': scale_parameter,
            'relative_step': relative_step,
            'warmup_init': warmup_init,
            'eps1': eps1,
            'eps2': eps2,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaFactor'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the optimizer state for every parameter of `group` that has a gradient.

        :param group: GROUP. parameter group to initialise.
        :param kwargs: `beta1` is read from here; when it is None no first-moment buffer is allocated.
        :raises NoSparseGradientError: if a parameter carries a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        beta1: float = kwargs.get('beta1')

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            grad_shape: Tuple[int, ...] = grad.shape
            factored: bool = self.get_options(grad_shape)

            if len(state) == 0:
                if beta1 is not None:
                    # first moment is stored in `momentum_dtype`, independent of the parameter dtype.
                    state['exp_avg'] = torch.zeros_like(p, dtype=self.momentum_dtype)

                if factored:
                    # row statistics drop the last axis, column statistics drop the second-to-last axis.
                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_sq_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                if group['ams_bound']:
                    state['exp_avg_sq_hat'] = torch.zeros_like(grad)

                state['RMS'] = 0.0

    def get_lr(
        self, lr: float, step: int, rms: float, relative_step: bool, warmup_init: bool, scale_parameter: bool
    ) -> float:
        r"""Get AdaFactor learning rate.

        :param lr: float. external learning rate, used as the step size when `relative_step` is False.
        :param step: int. current step count (1-based).
        :param rms: float. root-mean-square of the parameter.
        :param relative_step: bool. if true, the step size is `min(cap, 1 / sqrt(step))`, where `cap` is
            `1e-6 * step` when `warmup_init` is set and `1e-2` otherwise.
        :param warmup_init: bool. selects the cap of the relative step size (see `relative_step`).
        :param scale_parameter: bool. if true, the step size is multiplied by `max(eps2, rms)`.
        :return: the learning rate to apply to this parameter.
        """
        relative_step_size: float = lr
        if relative_step:
            min_step: float = 1e-6 * step if warmup_init else 1e-2
            relative_step_size = min(min_step, 1.0 / math.sqrt(step))

        # Scale by the parameter RMS (floored at eps2) only when scaling is requested. The previous expression
        # had the two branches swapped, which disabled scaling exactly when `scale_parameter` was True.
        param_scale: float = max(self.eps2, rms) if scale_parameter else 1.0

        return param_scale * relative_step_size

    @staticmethod
    def get_options(shape: Tuple[int, ...]) -> bool:
        r"""Get `factored`: True when the shape has at least two dimensions."""
        return len(shape) >= 2

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            # The first visit of a group initialises its state and starts the step counter at 1.
            if 'step' not in group:
                self.init_group(group, beta1=beta1)
                group['step'] = 1
            else:
                group['step'] += 1

            # time-dependent decay of the second-moment averages: 1 - step ** decay_rate.
            beta2_t: float = 1.0 - math.pow(group['step'], self.decay_rate)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_shape: Tuple[int, ...] = grad.shape
                factored: bool = self.get_options(grad_shape)

                state['RMS'] = self.get_rms(p)

                lr: float = self.get_lr(
                    lr=group['lr'],
                    step=group['step'],
                    rms=state['RMS'],
                    relative_step=group['relative_step'],
                    warmup_init=group['warmup_init'],
                    scale_parameter=group['scale_parameter'],
                )

                # `update` starts as grad ** 2 + eps1 and is reused as the working buffer below.
                update = torch.mul(grad, grad).add_(self.eps1)

                if factored:
                    exp_avg_sq_row, exp_avg_sq_col = state['exp_avg_sq_row'], state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
                    exp_avg_sq_col.mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

                    # presumably writes the factored second-moment approximation into `update` - TODO confirm
                    # against the base-class helper.
                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
                    # NOTE(review): the unfactored average is capped at `beta2` here; confirm this is intended.
                    exp_avg_sq.clamp_(max=beta2)

                    torch.rsqrt(exp_avg_sq, out=update)

                if group['ams_bound']:
                    # keep the element-wise maximum of 1 / update seen so far and rebuild `update` from it.
                    exp_avg_sq_hat = state['exp_avg_sq_hat']
                    torch.max(exp_avg_sq_hat, 1 / update, out=exp_avg_sq_hat)
                    torch.rsqrt(exp_avg_sq_hat / beta2_t, out=update)

                update.mul_(grad)

                # shrink the update when its RMS exceeds `clip_threshold`, then apply the learning rate.
                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)

                if beta1 is not None:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

                    # clone so the stored first moment is not modified by what follows.
                    update = exp_avg.clone()
                    if group.get('cautious'):
                        self.apply_cautious(update, grad)

                # weight decay uses the learning rate computed above, not group['lr'].
                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.add_(-update)

        return loss

get_lr(lr, step, rms, relative_step, warmup_init, scale_parameter)

Get AdaFactor learning rate.

Source code in pytorch_optimizer/optimizer/adafactor.py
125
126
127
128
129
130
131
132
133
134
135
136
def get_lr(
    self, lr: float, step: int, rms: float, relative_step: bool, warmup_init: bool, scale_parameter: bool
) -> float:
    r"""Get AdaFactor learning rate.

    :param lr: float. external learning rate, used as the step size when `relative_step` is False.
    :param step: int. current step count (1-based).
    :param rms: float. root-mean-square of the parameter.
    :param relative_step: bool. if true, the step size is `min(cap, 1 / sqrt(step))`, where `cap` is
        `1e-6 * step` when `warmup_init` is set and `1e-2` otherwise.
    :param warmup_init: bool. selects the cap of the relative step size (see `relative_step`).
    :param scale_parameter: bool. if true, the step size is multiplied by `max(eps2, rms)`.
    :return: the learning rate to apply to this parameter.
    """
    relative_step_size: float = lr
    if relative_step:
        min_step: float = 1e-6 * step if warmup_init else 1e-2
        relative_step_size = min(min_step, 1.0 / math.sqrt(step))

    # Scale by the parameter RMS (floored at eps2) only when scaling is requested. The previous expression
    # had the two branches swapped, which disabled scaling exactly when `scale_parameter` was True.
    param_scale: float = max(self.eps2, rms) if scale_parameter else 1.0

    return param_scale * relative_step_size

get_options(shape) staticmethod

Get factored.

Source code in pytorch_optimizer/optimizer/adafactor.py
138
139
140
141
@staticmethod
def get_options(shape: Tuple[int, ...]) -> bool:
    r"""Get `factored`."""
    return len(shape) >= 2

AdaGC

Bases: BaseOptimizer

Improving Training Stability for Large Language Model Pretraining.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the clipped gradient and its square.

(0.9, 0.999)
beta float

float. smoothing coefficient for EMA.

0.98
lambda_abs float

float. absolute clipping threshold to prevent unstable updates from gradient explosions.

1.0
lambda_rel float

float. relative clipping threshold to prevent unstable updates from gradient explosions.

1.05
warmup_steps int

int. warmup steps.

100
weight_decay float

float. weight decay (L2 penalty).

0.1
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adagc.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class AdaGC(BaseOptimizer):
    r"""Improving Training Stability for Large Language Model Pretraining.

    Every gradient is rescaled by a clipping factor before it enters the first and second moment averages. While
    the step counter is below `warmup_steps` the factor is derived from the global gradient norm and `lambda_abs`;
    afterwards it is derived from a per-parameter running norm (`gamma`) and `lambda_rel`.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients of the running averages of the clipped gradient and of its square.
    :param beta: float. smoothing coefficient for the EMA of the clipped gradient norm.
    :param lambda_abs: float. absolute clipping threshold to prevent unstable updates from gradient explosions.
    :param lambda_rel: float. relative clipping threshold to prevent unstable updates from gradient explosions.
    :param warmup_steps: int. warmup steps.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        beta: float = 0.98,
        lambda_abs: float = 1.0,
        lambda_rel: float = 1.05,
        warmup_steps: int = 100,
        weight_decay: float = 1e-1,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        # Reject invalid hyper-parameters before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(beta, 'beta', 0.0, 1.0, '[)')
        self.validate_positive(lambda_abs, 'lambda_abs')
        self.validate_positive(lambda_rel, 'lambda_rel')
        self.validate_non_negative(warmup_steps, 'warmup_steps')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # Per-group defaults; extra keyword arguments are forwarded into the group options.
        defaults: DEFAULTS = dict(
            lr=lr,
            betas=betas,
            beta=beta,
            lambda_abs=lambda_abs,
            lambda_rel=lambda_rel,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
        )
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaGC'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the moment buffers and the `gamma` scalar for every parameter of `group` that has a gradient.

        :param group: GROUP. parameter group to initialise.
        :raises NoSparseGradientError: if a parameter carries a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            param_state = self.state[param]
            if 'exp_avg' in param_state:
                continue

            param_state['exp_avg'] = torch.zeros_like(gradient)
            param_state['exp_avg_sq'] = torch.zeros_like(gradient)
            # `gamma` is allocated without a defined value; the step-1 warm-up branch of `step` overwrites it.
            param_state['gamma'] = torch.empty((1,), device=gradient.device, dtype=gradient.dtype)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # The first visit of a group initialises its state and starts the step counter at 1.
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            step_count = group['step']
            beta1, beta2 = group['betas']
            ema_beta = group['beta']
            eps = group['eps']

            bias_correction1: float = self.debias(beta1, step_count)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, step_count))

            in_warmup: bool = step_count < group['warmup_steps']

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg = param_state['exp_avg']
                exp_avg_sq = param_state['exp_avg_sq']
                gamma = param_state['gamma']

                if in_warmup:
                    # absolute clipping: the factor is lambda_abs / (global gradient norm + eps), at most 1.
                    global_norm = get_global_gradient_norm(self.param_groups).add_(eps)

                    clip_factor = min(group['lambda_abs'] / global_norm, 1.0)
                    clipped_grad = grad.mul(clip_factor)

                    clipped_norm = clipped_grad.norm()

                    # gamma is seeded on step 1 and afterwards tracks the smallest clipped norm seen in warm-up.
                    if step_count == 1:
                        gamma.copy_(clipped_norm)
                    else:
                        gamma.copy_(min(gamma, clipped_norm))
                else:
                    # relative clipping: the factor is lambda_rel * gamma / ||grad||, at most 1.
                    clip_factor = min(group['lambda_rel'] * gamma / grad.norm(), 1.0)
                    clipped_grad = grad.mul(clip_factor)

                    gamma.mul_(ema_beta).add_(clipped_grad.norm(), alpha=1.0 - ema_beta)

                exp_avg.mul_(beta1).add_(clipped_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(clipped_grad, clipped_grad, value=1.0 - beta2)

                # bias-corrected first moment over bias-corrected root of the second moment (plus eps).
                numerator = exp_avg / bias_correction1
                denominator = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(eps)
                update = numerator / denominator

                param.add_(update, alpha=-group['lr'])

        return loss

AdaHessian

Bases: BaseOptimizer

An Adaptive Second Order Optimizer for Machine Learning.

Requires `loss.backward(create_graph=True)` in order to calculate hessians.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.1
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
hessian_power float

float. exponent of the hessian trace.

1.0
update_period int

int. number of steps after which to apply hessian approximation.

1
num_samples int

int. times to sample z for the approximation of the hessian trace.

1
hessian_distribution HUTCHINSON_G

HUTCHINSON_G. type of distribution to initialize hessian.

'rademacher'
eps float

float. term added to the denominator to improve numerical stability.

1e-16
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adahessian.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class AdaHessian(BaseOptimizer):
    r"""An Adaptive Second Order Optimizer for Machine Learning.

    Requires `loss.backward(create_graph=True)` in order to calculate hessians.

    The parameter update is the running average of the gradient divided by a power of the running average of the
    squared hessian estimate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param hessian_power: float. exponent of the hessian trace. validated to lie in (0, 1].
    :param update_period: int. number of steps after which to apply hessian approximation.
    :param num_samples: int. times to sample `z` for the approximation of the hessian trace.
    :param hessian_distribution: HUTCHINSON_G. type of distribution to initialize hessian.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        hessian_power: float = 1.0,
        update_period: int = 1,
        num_samples: int = 1,
        hessian_distribution: HUTCHINSON_G = 'rademacher',
        eps: float = 1e-16,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')
        self.validate_range(hessian_power, 'Hessian Power', 0, 1, range_type='(]')

        # These options are optimizer-wide: `step` reads them from `self`, not from the parameter groups.
        self.update_period = update_period
        self.num_samples = num_samples
        self.distribution = hessian_distribution
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'hessian_power': hessian_power,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaHessian'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the zero-filled running averages for every parameter of `group` that has a gradient.

        :param group: GROUP. parameter group to initialise.
        :raises NoSparseGradientError: if a parameter carries a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_hessian_diag_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :param hessian: Optional[List[torch.Tensor]]. externally supplied hessian tensors; when given they are
            passed to `set_hessian` and the Hutchinson estimate is not run for this call.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Step counter of the first group as stored before this call increments it (1 when not yet set).
        # NOTE(review): the refresh gate below tests this pre-increment value, while the accumulation gate in
        # the parameter loop tests the post-increment value; for update_period > 1 the two conditions hold on
        # different calls - confirm this is intended.
        step: int = self.param_groups[0].get('step', 1)

        if hessian is not None:
            self.set_hessian(self.param_groups, self.state, hessian)
        elif step % self.update_period == 0:
            self.zero_hessian(self.param_groups, self.state)
            self.compute_hutchinson_hessian(
                param_groups=self.param_groups,
                state=self.state,
                num_samples=self.num_samples,
                distribution=self.distribution,
            )

        for group in self.param_groups:
            # The first visit of a group initialises its state and starts the step counter at 1.
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            # `adam_debias` is an optional group option (defaults to False when absent); how it alters the step
            # size is decided by the base-class helper.
            step_size: float = self.apply_adam_debias(group.get('adam_debias', False), group['lr'], bias_correction1)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                # running average of the gradient (updated in place).
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                # The squared-hessian average is only refreshed when a hessian entry exists in the state and
                # either the (post-increment) step is a multiple of `update_period` or a hessian was passed in.
                if 'hessian' in state and (group['step'] % self.update_period == 0 or hessian is not None):
                    exp_hessian_diag_sq.mul_(beta2).addcmul_(state['hessian'], state['hessian'], value=1.0 - beta2)

                # denominator: (bias-corrected squared-hessian average) ** (hessian_power / 2) + eps.
                de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(group['hessian_power'] / 2).add_(group['eps'])

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss

AdaLOMO

Bases: BaseOptimizer

Low-memory Optimization with Adaptive Learning Rate.

Parameters:

Name Type Description Default
model Module

nn.Module. pytorch model.

required
lr float

float. learning rate.

0.001
weight_decay float

float. weight decay (L2 penalty).

0.0
loss_scale float

float. loss scale.

2.0 ** 10
clip_threshold float

float. threshold of root-mean-square of final gradient update.

1.0
decay_rate float

float. coefficient used to compute running averages of square gradient.

-0.8
clip_grad_norm Optional[float]

Optional[float]. clip grad norm.

None
clip_grad_value Optional[float]

Optional[float]. clip grad value.

None
eps1 float

float. term added to the denominator to improve numerical stability.

1e-30
eps2 float

float. term added to the denominator to improve numerical stability.

0.001
Source code in pytorch_optimizer/optimizer/lomo.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
class AdaLOMO(BaseOptimizer):
    r"""Low-memory Optimization with Adaptive Learning Rate.

    The parameter update is fused into the backward pass: a hook is registered on every trainable parameter, and each
    time it fires it walks the model, applies the update for every parameter whose gradient is available and releases
    that gradient immediately. Drive the optimizer with `fused_backward`; `grad_norm` computes the coefficient that is
    used for gradient-norm clipping.

    :param model: nn.Module. pytorch model.
    :param lr: float. learning rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param loss_scale: float. loss scale.
    :param clip_threshold: float. threshold of root-mean-square of final gradient update.
    :param decay_rate: float. coefficient used to compute running averages of square gradient.
    :param clip_grad_norm: Optional[float]. clip grad norm.
    :param clip_grad_value: Optional[float]. clip grad value.
    :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        loss_scale: float = 2.0 ** 10,
        clip_threshold: float = 1.0,
        decay_rate: float = -0.8,
        clip_grad_norm: Optional[float] = None,
        clip_grad_value: Optional[float] = None,
        eps1: float = 1e-30,
        eps2: float = 1e-3,
        **kwargs,
    ) -> None:  # fmt: skip
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(loss_scale, 'loss_scale')
        self.validate_non_negative(clip_threshold, 'clip_threshold')
        self.validate_non_negative(clip_grad_norm, 'clip_grad_norm')
        self.validate_non_negative(clip_grad_value, 'clip_grad_value')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.model = model
        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_scale = loss_scale
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.clip_grad_norm = clip_grad_norm
        self.clip_grad_value = clip_grad_value
        self.eps1 = eps1
        self.eps2 = eps2

        # bookkeeping shared between `grad_norm`, `fused_backward` and the gradient hook
        self.num_steps: int = 0
        self.gather_norm: bool = False  # True only while `grad_norm` is collecting per-parameter norms
        self.grad_norms: List[torch.Tensor] = []
        self.clip_coef: Optional[float] = None  # set by `grad_norm`; norm clipping is skipped while it is None

        self.local_rank: int = int(os.environ.get('LOCAL_RANK', '0'))
        self.zero3_enabled: bool = is_deepspeed_zero3_enabled()

        # the hook must exist before `initialize_states` registers it on the parameters
        self.grad_func: Callable[[Any], Any] = self.fuse_update_zero3() if self.zero3_enabled else self.fuse_update()

        # second-moment statistics keyed by parameter name: a full vector for 1-D parameters,
        # separate row / column vectors for the others
        self.exp_avg_sq = {}
        self.exp_avg_sq_row = {}
        self.exp_avg_sq_col = {}

        self.initialize_states()

        defaults: DEFAULTS = {
            'lr': lr,
            'weight_decay': weight_decay,
            'clip_grad_norm': clip_grad_norm,
            'clip_grad_value': clip_grad_value,
            'eps1': eps1,
            'eps2': eps2,
        }

        super().__init__(self.model.parameters(), defaults)

    def __str__(self) -> str:
        return 'AdaLOMO'

    def initialize_states(self) -> None:
        r"""Allocate the fp32 second-moment buffers and register the fused-update hook.

        Buffers are created for every named parameter; the hook is registered only on parameters that require
        gradients. Under ZeRO-3 the shapes are read from `p.ds_shape` instead of `p.shape`.
        """
        for n, p in self.model.named_parameters():
            if self.zero3_enabled:  # pragma: no cover
                if len(p.ds_shape) == 1:
                    self.exp_avg_sq[n] = torch.zeros(p.ds_shape[0], dtype=torch.float32, device=p.device)
                else:
                    self.exp_avg_sq_row[n] = torch.zeros(p.ds_shape[0], dtype=torch.float32, device=p.device)
                    self.exp_avg_sq_col[n] = torch.zeros(p.ds_shape[1], dtype=torch.float32, device=p.device)
            elif len(p.shape) == 1:
                self.exp_avg_sq[n] = torch.zeros(p.shape[0], dtype=torch.float32, device=p.device)
            else:
                self.exp_avg_sq_row[n] = torch.zeros(p.shape[0], dtype=torch.float32, device=p.device)
                self.exp_avg_sq_col[n] = torch.zeros(p.shape[1], dtype=torch.float32, device=p.device)

            if p.requires_grad:
                p.register_hook(self.grad_func)

    def init_group(self, group: GROUP, **kwargs) -> None:
        # all state lives on the optimizer itself (see `initialize_states`), so there is nothing to do per group
        pass

    def fuse_update(self) -> Callable[[Any], Any]:
        r"""Build the gradient hook used when ZeRO-3 is not enabled.

        The returned callable ignores its argument's value and returns it unchanged. On each call it visits every
        trainable parameter that currently holds a gradient, takes an fp32 copy of the gradient, frees `p.grad`, and
        then either records the gradient norm (while `gather_norm` is set) or applies the adaptive update.
        """

        @torch.no_grad()
        def func(x: Any) -> Any:
            for n, p in self.model.named_parameters():
                if not p.requires_grad or p.grad is None:
                    continue

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                if self.loss_scale:
                    grad_fp32.div_(self.loss_scale)

                if self.gather_norm:
                    # norm-collection pass: record the norm only, never touch the parameter
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                    continue

                if self.clip_grad_value is not None and self.clip_grad_value > 0.0:
                    grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)
                if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is not None:
                    grad_fp32.mul_(self.clip_coef)

                # the exponent sign is flipped at step 0 so that, with a negative `decay_rate`,
                # 0 is never raised to a negative power
                beta2_t: float = 1.0 - math.pow(
                    self.num_steps, self.decay_rate if self.num_steps > 0 else -self.decay_rate
                )

                update = grad_fp32.pow(2).add_(self.eps1)

                if len(p.shape) > 1:
                    self.exp_avg_sq_row[n].mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
                    self.exp_avg_sq_col[n].mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

                    # `update` is passed as the third argument and then scaled by the gradient
                    # (presumably the helper writes its result into it -- confirm against BaseOptimizer)
                    self.approximate_sq_grad(self.exp_avg_sq_row[n], self.exp_avg_sq_col[n], update)
                    update.mul_(grad_fp32)
                else:
                    self.exp_avg_sq[n].mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
                    update = self.exp_avg_sq[n].rsqrt().mul_(grad_fp32)

                # shrink the update when its RMS exceeds `clip_threshold`
                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0))

                p_fp32 = p.to(torch.float32)
                p_rms = torch.norm(p_fp32, 2.0) / math.sqrt(p.numel())

                # learning rate relative to the parameter's RMS, floored at `eps2`
                lr = self.lr * max(self.eps2, p_rms)

                # Decay the fp32 working copy rather than `p`: for a low-precision parameter `p_fp32` is a separate
                # tensor and the `p.copy_` below would overwrite anything applied to `p` directly. For an fp32
                # parameter `p_fp32` is `p` itself, so this is unchanged. (Same target as the ZeRO-3 path.)
                self.apply_weight_decay(
                    p_fp32,
                    grad_fp32,
                    lr,
                    self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                # apply the adaptive, RMS-clipped update (not the raw gradient), as the ZeRO-3 path does
                p_fp32.add_(update, alpha=-lr)
                p.copy_(p_fp32)

            return x

        return func

    def fuse_update_zero3(self) -> Callable[[Any], Any]:  # pragma: no cover
        r"""Build the gradient hook used when DeepSpeed ZeRO-3 is enabled.

        Same flow as `fuse_update`, except that the gradient is first averaged across ranks with `all_reduce`, only
        the slice of the update selected by `local_rank` is applied to the local shard `p.ds_tensor`, and the
        parameter RMS is obtained by summing the squared shard norms across ranks.
        """

        @torch.no_grad()
        def func(x: torch.Tensor) -> torch.Tensor:
            for n, p in self.model.named_parameters():
                if p.grad is None:
                    continue

                all_reduce(p.grad, op=ReduceOp.AVG, async_op=False)

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                if self.loss_scale:
                    grad_fp32.div_(self.loss_scale)

                if self.gather_norm:
                    # norm-collection pass: record the norm only. Everything below must be skipped, otherwise the
                    # parameters would be updated during `grad_norm` and `start` / `end` would be undefined.
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                    continue

                partition_size: int = p.ds_tensor.numel()
                start = partition_size * self.local_rank
                end = min(start + partition_size, grad_fp32.numel())

                # a zero `clip_grad_value` means "disabled", matching `fuse_update`
                if self.clip_grad_value is not None and self.clip_grad_value > 0.0:
                    grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)
                if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is not None:
                    grad_fp32.mul_(self.clip_coef)

                # the exponent sign is flipped at step 0 so that, with a negative `decay_rate`,
                # 0 is never raised to a negative power
                beta2_t: float = 1.0 - math.pow(
                    self.num_steps, self.decay_rate if self.num_steps > 0 else -self.decay_rate
                )

                update = grad_fp32.pow(2).add_(self.eps1)

                if len(p.ds_shape) > 1:
                    self.exp_avg_sq_row[n].mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
                    self.exp_avg_sq_col[n].mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

                    self.approximate_sq_grad(self.exp_avg_sq_row[n], self.exp_avg_sq_col[n], update)
                    update.mul_(grad_fp32)
                else:
                    self.exp_avg_sq[n].mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
                    update = self.exp_avg_sq[n].rsqrt().mul_(grad_fp32)

                # shrink the update when its RMS exceeds `clip_threshold`
                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0))

                # slice of the flattened update that corresponds to this rank's shard
                one_dim_update = update.view(-1)
                partitioned_update = one_dim_update.narrow(0, start, end - start)

                param_fp32 = p.ds_tensor.to(torch.float32)
                partitioned_p = param_fp32.narrow(0, 0, end - start)

                # RMS of the whole parameter: sum the squared shard norms over ranks, then normalise
                p_rms = torch.norm(partitioned_p, 2.0).pow_(2)
                all_reduce(p_rms, op=ReduceOp.SUM)
                p_rms.div_(p.ds_numel).sqrt_()

                lr = self.lr * max(self.eps2, p_rms)

                self.apply_weight_decay(
                    p=partitioned_p,
                    grad=grad_fp32,
                    lr=lr,
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                partitioned_p.add_(partitioned_update, alpha=-lr)

                p.ds_tensor[:end - start] = partitioned_p  # fmt: skip

            return x

        return func

    def fused_backward(self, loss, lr: float) -> None:
        r"""Run the backward pass with the parameter update fused into it.

        :param loss: loss tensor to back-propagate; it is multiplied by `loss_scale` first when that is non-zero.
        :param lr: float. learning rate to use for this step (stored on `self.lr`).
        """
        self.lr = lr

        if self.loss_scale:
            loss = loss * self.loss_scale

        self.num_steps += 1

        loss.backward()

        # one extra pass to consume any gradient that is still attached after backward has finished
        self.grad_func(0)

    def grad_norm(self, loss) -> None:
        r"""Compute the global gradient norm and derive the clipping coefficient `clip_coef`.

        Runs a backward pass (keeping the graph) with `gather_norm` enabled so the hook only records per-parameter
        gradient norms. `clip_grad_norm` must be set: it is converted with `float(...)`.

        :param loss: loss tensor to back-propagate; it is multiplied by `loss_scale` first when that is non-zero.
        """
        self.gather_norm = True
        self.grad_norms = []

        if self.loss_scale:
            loss = loss * self.loss_scale

        loss.backward(retain_graph=True)

        self.grad_func(0)

        with torch.no_grad():
            self.grad_norms = torch.stack(self.grad_norms)

            total_norm = torch.norm(self.grad_norms, 2.0)
            # never scale gradients up: the coefficient is capped at 1
            self.clip_coef = torch.clamp(float(self.clip_grad_norm) / (total_norm + 1e-6), max=1.0)

        self.gather_norm = False

Adai

Bases: BaseOptimizer

Disentangling the Effects of Adaptive Learning Rate and Momentum.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. (beta0, beta2): beta0 scales the per-element adaptive momentum coefficient, and beta2 is the decay rate of the running average of the squared gradient.

(0.1, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
stable_weight_decay bool

bool. perform stable weight decay.

False
dampening float

float. dampening for momentum. where dampening < 1, it will show some adaptive-moment behavior.

1.0
eps float

float. term added to the denominator to improve numerical stability.

0.001
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adai.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class Adai(BaseOptimizer):
    r"""Disentangling the Effects of Adaptive Learning Rate and Momentum.

    Each step makes two passes over the parameters: the first updates the running average of the squared gradient and
    accumulates its bias-corrected mean over all parameters; the second derives a per-element momentum coefficient
    from the ratio of each element's second moment to that mean and applies the momentum update.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. (beta0, beta2). beta0 scales the per-element adaptive momentum coefficient and beta2 is the
        decay rate of the running average of the squared gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param stable_weight_decay: bool. perform stable weight decay.
    :param dampening: float. dampening for momentum. where dampening < 1, it will show some adaptive-moment behavior.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.1, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        stable_weight_decay: bool = False,
        dampening: float = 1.0,
        eps: float = 1e-3,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # extra keyword arguments (e.g. `use_gc`, read in `step`) become per-group options
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'stable_weight_decay': stable_weight_decay,
            'dampening': dampening,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adai'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of `group` that currently has a gradient.

        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                # running product of the per-element beta1 values, used to de-bias `exp_avg`
                state['beta1_prod'] = torch.ones_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        :raises ZeroParameterSizeError: if no parameter has a gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        # pass 1: update the second moments and accumulate their bias-corrected sum over all parameters
        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            _, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                param_size += p.numel()

                # The return value is unused, so the helper can only act on `grad` in place. It is therefore
                # applied exactly once per step, here; pass 2 reuses the already-adjusted `p.grad`.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if group.get('use_gc'):
                    centralize_gradient(grad, gc_conv_only=False)

                if not group['stable_weight_decay'] and group['weight_decay'] > 0.0:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                exp_avg_sq_hat_sum += exp_avg_sq.sum() / bias_correction2

        if param_size == 0:
            raise ZeroParameterSizeError()

        # mean of the bias-corrected second moment over every element that received a gradient
        exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

        # pass 2: per-element adaptive momentum and the parameter update
        for group in self.param_groups:
            beta0, beta2 = group['betas']

            beta0_dp: float = math.pow(beta0, 1.0 - group['dampening'])
            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                # `maximize_gradient` was already applied to this tensor in pass 1. Calling it again here would
                # undo that change and make the update below ignore `maximize`.
                grad = p.grad

                state = self.state[p]

                if group['stable_weight_decay'] and group['weight_decay'] > 0.0:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                exp_avg_sq_hat = exp_avg_sq / bias_correction2
                # per-element momentum coefficient, clamped to [0, 1 - eps]
                beta1 = (
                    1.0
                    - (exp_avg_sq_hat / exp_avg_sq_hat_mean).pow_(1.0 / (3.0 - 2.0 * group['dampening'])).mul_(beta0)
                ).clamp_(0.0, 1.0 - group['eps'])
                beta3 = (1.0 - beta1).pow_(group['dampening'])

                beta1_prod = state['beta1_prod']
                beta1_prod.mul_(beta1)

                exp_avg.mul_(beta1).addcmul_(beta3, grad)
                # de-bias with the running product of beta1 (the coefficient changes every step)
                exp_avg_hat = exp_avg.div(1.0 - beta1_prod).mul_(beta0_dp)

                p.add_(exp_avg_hat, alpha=-group['lr'])

        return loss

Adalite

Bases: BaseOptimizer

Adalite optimizer.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.01
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
g_norm_min float

float. lower bound applied to the gradient norm when computing the trust ratio.

1e-10
ratio_min float

float. lower bound applied to the trust ratio (parameter norm divided by gradient norm).

0.0001
tau float

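float. divisor applied to the row and column means of the second moment before the softmax that weights the factored momentum.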

1.0
eps1 float

float. term added to the denominator to improve numerical stability.

1e-06
eps2 float

float. term added to the denominator to improve numerical stability.

1e-10
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adalite.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class Adalite(BaseOptimizer):
    r"""Adalite optimizer.

    Gradients are rescaled by a trust ratio (parameter norm over gradient norm). Parameters with fewer than two
    dimensions keep full first and second moments; the others keep factored state that is re-expanded at the start of
    every step and re-factored at the end of it.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param g_norm_min: float. lower bound applied to the gradient norm when computing the trust ratio.
    :param ratio_min: float. lower bound applied to the trust ratio.
    :param tau: float. divisor applied to the row and column means of the second moment before the softmax that
        weights the factored momentum.
    :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 1e-2,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        g_norm_min: float = 1e-10,
        ratio_min: float = 1e-4,
        tau: float = 1.0,
        eps1: float = 1e-6,
        eps2: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.maximize = maximize

        # NOTE(review): `weight_decouple` and `fixed_decay` are stored here but `step` below never reads them.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'g_norm_min': g_norm_min,
            'ratio_min': ratio_min,
            'tau': tau,
            'eps1': eps1,
            'eps2': eps2,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adalite'

    def init_group(self, group: GROUP, **kwargs) -> None:
        # Creates the state of every parameter in `group` that currently has a gradient.
        # Raises NoSparseGradientError for sparse gradients and NoComplexParameterError for complex parameters.
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if len(p.shape) < 2:
                    # fewer than two dimensions: full first and second moments
                    state['m_avg'] = torch.zeros_like(p)
                    state['v_avg'] = torch.zeros_like(p)
                else:
                    # factored second moment: one vector shaped like the mean over dim 1, one like the mean over dim 0
                    state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                    state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

                    # factored first moment: column factor, row factor and a scalar factor kept as a (1, 1) tensor
                    state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
                    state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
                    state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        # Performs a single optimization step and returns the closure's loss (None when no closure is given).
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # state is created only on the first step of a group
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # trust ratio ||p|| / max(||g||, g_norm_min), floored at ratio_min; `p.grad` is scaled in place.
                # Skipped only when the shape entries sum to at most 1 (0-d or single-element 1-d gradients).
                if sum(grad.shape) > 1:
                    trust_ratio = (p.norm() / grad.norm().clip(min=group['g_norm_min'])).clip(min=group['ratio_min'])
                    grad.mul_(trust_ratio)

                if len(grad.shape) < 2:
                    m = state['m_avg']
                    v = state['v_avg']
                else:
                    # re-expand the factored state into full-size temporaries for this step
                    r, c = state['v_avg_0'][:, None], state['v_avg_1'][None, :]
                    v = (r * c) / r.sum().clamp(min=group['eps2'])
                    m = state['m_avg_c'] @ state['m_avg_u'] @ state['m_avg_r']

                # in-place moving averages; the second moment tracks the squared deviation of grad from the updated m
                m.lerp_(grad, 1.0 - beta1)
                v.lerp_((grad - m).square(), 1.0 - beta2)

                # bias-corrected second moment
                v_avg = v / (1.0 - beta2 ** group['step'])

                # exactly-2-D gradients only: pull m further towards grad, weighted by softmax importances of v
                if len(grad.shape) == 2:
                    imp_c = softmax(v.mean(dim=1), dim=0)[:, None]
                    imp_r = softmax(v.mean(dim=0), dim=0)[None, :]
                    m.lerp_(grad, 1.0 - imp_c * imp_r)

                # out-of-place: `u` is a new tensor, `m` is left untouched for the state update below
                u = m.lerp(grad, 1.0 - beta1)

                if len(grad.shape) < 2:
                    state['m_avg'] = m
                    state['v_avg'] = v
                else:
                    # re-factor the second moment into row sums and normalised column sums
                    state['v_avg_0'] = v.sum(dim=1)
                    state['v_avg_1'] = v.sum(dim=0) / v.sum().clamp(min=group['eps2'])

                    # importance weights; `tau` divides the means before the softmax
                    imp_c = softmax(v.mean(dim=1) / group['tau'], dim=-1)[:, None]
                    imp_r = softmax(v.mean(dim=0) / group['tau'], dim=-1)[None, :]

                    # importance-weighted column and row factors of m
                    c = ((m * imp_r).sum(dim=1))[:, None]
                    r = ((m * imp_c).sum(dim=0))[None, :]

                    # (1, 1) scale applied between c and r when m is re-expanded as c @ s @ r on the next step
                    s = (c.T @ m @ r.T) / (c.T @ c @ r @ r.T).clamp(min=group['eps2'])

                    state['m_avg_c'] = c
                    state['m_avg_r'] = r
                    state['m_avg_u'] = s

                # normalise by the root of the bias-corrected second moment
                u.div_((v_avg + group['eps1']).sqrt())

                u = u.reshape(p.shape)
                # weight decay is folded into the update: p <- p - lr * (u + weight_decay * p)
                u.add_(p, alpha=group['weight_decay'])

                p.add_(u, alpha=-group['lr'])

        return loss

AdamMini

Bases: BaseOptimizer

Use Fewer Learning Rates To Gain More.

Parameters:

Name Type Description Default
model Module

nn.Module. model instance.

required
model_sharding bool

bool. set to True if you are using model parallelism with more than 1 GPU, including FSDP and zero_1, 2, 3 in Deepspeed. Set to False if otherwise.

False
lr float

float. learning rate.

1.0
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.1
num_embeds int

int. number of embedding dimensions. could be unspecified if you are training non-transformer models.

2048
num_heads int

int. number of attention heads. could be unspecified if you are training non-transformer models.

32
num_query_groups Optional[int]

Optional[int]. number of query groups in Group Query Attention (GQA). if not specified, it will be equal to num_heads. could be unspecified if you are training non-transformer models.

None
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adam_mini.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
class AdamMini(BaseOptimizer):  # pragma: no cover
    r"""Use Fewer Learning Rates To Gain More.

    Every trainable parameter of `model` becomes its own parameter group, and the update rule applied to it is
    selected from the parameter name:

    * names containing an entry of `embed_blocks` keep an element-wise second moment (`step_embed`).
    * names containing an entry of `qk_blocks` share one second moment per head (`step_attn_proj`).
    * names containing `attn.attn.weight` or `attn.qkv.weight` share one second moment per
      `(head, slot)` pair (`step_attn`).
    * every other parameter shares a single scalar second moment (`step_lefts`).

    :param model: nn.Module. model instance.
    :param model_sharding: bool. set to True if you are using model parallelism with more than 1 GPU, including FSDP
        and zero_1, 2, 3 in Deepspeed. Set to False if otherwise.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the (block-wise mean
        of the) squared gradient.
    :param weight_decay: float. weight decay. parameters whose name contains `norm` or `ln_f` always use 0.0.
    :param num_embeds: int. number of embedding dimensions. could be unspecified if you are training non-transformer
        models.
    :param num_heads: int. number of attention heads. could be unspecified if you are training non-transformer models.
    :param num_query_groups: Optional[int]. number of query groups in Group Query Attention (GQA). if not specified, it
        falls back to `num_embeds`, which makes `q_per_kv` equal to 1. could be unspecified if you are training
        non-transformer models.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.1,
        model_sharding: bool = False,
        num_embeds: int = 2048,
        num_heads: int = 32,
        num_query_groups: Optional[int] = None,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(num_embeds, 'num_embeds')
        self.validate_non_negative(num_heads, 'num_heads')
        self.validate_non_negative(eps, 'eps')

        # NOTE(review): earlier documentation stated that the fallback is `num_heads`, but the code has always used
        # `num_embeds` (giving q_per_kv == 1 in `get_optimizer_groups`). Behaviour is kept as-is; confirm the intent.
        self.num_query_groups: int = num_query_groups if num_query_groups is not None else num_embeds
        self.validate_mod(num_embeds, self.num_query_groups)

        self.world_size: int = torch.cuda.device_count()

        self.model = model
        self.model_sharding = model_sharding
        self.num_embeds = num_embeds
        self.num_heads = num_heads

        # name fragments used to route a parameter to its update rule in `step`.
        self.embed_blocks: Set[str] = {'embed', 'embd', 'wte', 'lm_head.weight', 'output.weight'}
        self.qk_blocks: Set[str] = {'k_proj.weight', 'q_proj.weight', 'wq.weight', 'wk.weight'}

        self.maximize = maximize

        # must run after the attributes above are set, because it reads model / num_* / qk_blocks.
        groups = self.get_optimizer_groups(weight_decay)

        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'eps': eps, **kwargs}

        super().__init__(groups, defaults)

    def __str__(self) -> str:
        return 'AdamMini'

    def get_optimizer_groups(self, weight_decay: float):
        r"""Build one parameter group per trainable parameter of the model.

        Each group carries the parameter name (used later for routing), its weight decay, and, for attention
        parameters, the block-size metadata needed by the per-head update rules.

        :param weight_decay: float. weight decay applied to every parameter whose name does not contain `norm` or
            `ln_f`.
        :return: list of parameter group dicts.
        """
        groups = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            group = {
                'name': name,
                'params': param,
                'weight_decay': 0.0 if ('norm' in name or 'ln_f' in name) else weight_decay,
            }

            if any(block in name for block in self.qk_blocks):
                # number of elements that share one second-moment value in `step_attn_proj`.
                group['parameter_per_head'] = self.num_embeds * self.num_embeds // self.num_heads

            if 'attn.attn.weight' in name or 'attn.qkv.weight' in name:
                group['n_head'] = self.num_heads
                group['q_per_kv'] = self.num_embeds // self.num_query_groups

            groups.append(group)

        return groups

    def init_group(self, group: GROUP, **kwargs) -> None:
        # state is created lazily inside each `step_*` routine, so there is nothing to initialize here.
        pass

    @staticmethod
    def step_embed(
        p,
        grad,
        state,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        r"""Update a parameter with an element-wise second moment.

        :param p: parameter tensor, updated in-place.
        :param grad: gradient of `p`.
        :param state: per-parameter state dict. `m` and `v` are created as float32 buffers on first use.
        :param lr: float. learning rate.
        :param beta1: float. decay rate of the first moment.
        :param beta2: float. decay rate of the second moment.
        :param bias_correction1: float. bias correction for the first moment.
        :param bias_correction2_sq: float. square root of the bias correction for the second moment.
        :param eps: float. term added to the denominator.
        """
        if len(state) == 0:
            state['m'] = torch.zeros_like(p, dtype=torch.float32)
            state['v'] = torch.zeros_like(p, dtype=torch.float32)

        m, v = state['m'], state['v']

        m.lerp_(grad, weight=1.0 - beta1)
        v.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        p.addcdiv_(m, h, value=-lr / bias_correction1)

    @staticmethod
    def step_attn_proj(
        p,
        grad,
        state,
        parameter_per_head: int,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        r"""Update a query / key projection with one second-moment value per head.

        The parameter is viewed as `(head, parameter_per_head)`; the squared gradient is averaged over each row and
        that mean drives the step size of the whole row.

        :param p: parameter tensor, updated in-place. its number of elements must be a multiple of
            `parameter_per_head`.
        :param grad: gradient of `p`.
        :param state: per-parameter state dict. `m`, `head` and `v_mean` are created on first use.
        :param parameter_per_head: int. number of elements sharing one second-moment value.
        :param lr: float. learning rate.
        :param beta1: float. decay rate of the first moment.
        :param beta2: float. decay rate of the second moment.
        :param bias_correction1: float. bias correction for the first moment.
        :param bias_correction2_sq: float. square root of the bias correction for the second moment.
        :param eps: float. term added to the denominator.
        """
        if len(state) == 0:
            state['m'] = torch.zeros_like(p, dtype=torch.float32).view(-1, parameter_per_head)
            state['head'] = state['m'].shape[0]
            state['v_mean'] = torch.zeros(state['head'], device=state['m'].device)

        m, v = state['m'], state['v_mean']

        head: int = state['head']
        grad = grad.view(head, parameter_per_head)

        m.lerp_(grad, weight=1.0 - beta1)

        tmp_lr = torch.mean(grad * grad, dim=1).to(m.device)
        v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        # the scale has shape (head, 1). the multiply must be out-of-place: an in-place `mul_` on the scale cannot
        # broadcast its own storage up to (head, parameter_per_head) and raises a RuntimeError.
        scale = (1 / (h * bias_correction1)).view(head, 1)
        update = (m * scale).view_as(p)

        p.add_(update, alpha=-lr)

    @staticmethod
    def step_attn(
        p,
        grad,
        state,
        num_heads: int,
        q_per_kv: int,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        r"""Update a fused attention weight with one second-moment value per `(head, slot)` pair.

        The parameter is viewed as `(num_heads, q_per_kv + 2, -1)`; the squared gradient is averaged over the last
        axis and that mean drives the step size of the corresponding slice.

        :param p: parameter tensor, updated in-place. its number of elements must be a multiple of
            `num_heads * (q_per_kv + 2)`.
        :param grad: gradient of `p`.
        :param state: per-parameter state dict. `m` and `v_mean` are created on first use.
        :param num_heads: int. size of the first axis of the view.
        :param q_per_kv: int. the second axis of the view has `q_per_kv + 2` entries.
        :param lr: float. learning rate.
        :param beta1: float. decay rate of the first moment.
        :param beta2: float. decay rate of the second moment.
        :param bias_correction1: float. bias correction for the first moment.
        :param bias_correction2_sq: float. square root of the bias correction for the second moment.
        :param eps: float. term added to the denominator.
        """
        if len(state) == 0:
            state['m'] = torch.zeros_like(p, dtype=torch.float32).view(num_heads, q_per_kv + 2, -1)
            state['v_mean'] = torch.zeros(num_heads, q_per_kv + 2, device=state['m'].device)

        m, v = state['m'], state['v_mean']

        grad = grad.view(num_heads, q_per_kv + 2, -1)

        m.lerp_(grad, weight=1.0 - beta1)

        tmp_lr = torch.mean(grad * grad, dim=2).to(m.device)
        v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        # the scale has shape (num_heads, q_per_kv + 2, 1). as in `step_attn_proj`, the multiply must be
        # out-of-place because an in-place op cannot grow the scale tensor to the shape of `m`.
        scale = (1 / (h * bias_correction1)).view(num_heads, q_per_kv + 2, -1)
        update = (m * scale).view_as(p)

        p.add_(update, alpha=-lr)

    def step_lefts(
        self,
        p,
        grad,
        state,
        lr: float,
        beta1: float,
        beta2: float,
        bias_correction1: float,
        bias_correction2_sq: float,
        eps: float,
    ) -> None:
        r"""Update any remaining parameter with a single scalar second-moment value.

        On first use the number of elements is recorded. When `model_sharding` is enabled and `world_size > 1`, the
        element counts of all ranks are gathered and summed; if at least two ranks hold a non-empty shard, the squared
        gradient sum is all-reduced on every step so that the mean covers the whole tensor.

        :param p: parameter tensor, updated in-place.
        :param grad: gradient of `p`.
        :param state: per-parameter state dict. `m`, `v_mean`, `dimension` and `reduced` are created on first use.
        :param lr: float. learning rate.
        :param beta1: float. decay rate of the first moment.
        :param beta2: float. decay rate of the second moment.
        :param bias_correction1: float. bias correction for the first moment.
        :param bias_correction2_sq: float. square root of the bias correction for the second moment.
        :param eps: float. term added to the denominator.
        """
        if len(state) == 0:
            dim = torch.tensor(p.numel(), device=p.device, dtype=torch.float32)

            reduced: bool = False
            if self.model_sharding and self.world_size > 1:
                tensor_list = [torch.zeros_like(dim) for _ in range(self.world_size)]
                dist.all_gather(tensor_list, dim)

                # s counts the ranks that own a non-empty shard; dim becomes the total element count.
                s, dim = 0, 0
                for d in tensor_list:
                    if d > 0:
                        s += 1
                    dim += d

                if s >= 2:
                    reduced = True

            state['m'] = torch.zeros_like(p, dtype=torch.float32)
            state['v_mean'] = torch.tensor(0.0, device=state['m'].device)
            state['dimension'] = dim
            state['reduced'] = reduced

        tmp_lr = torch.sum(grad * grad)

        if state['reduced']:
            dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM)

        # mean of the squared gradient over all elements of the (possibly sharded) tensor.
        tmp_lr.div_(state['dimension'])

        m, v = state['m'], state['v_mean']

        m.lerp_(grad, weight=1.0 - beta1)
        v.mul_(beta2).add_(tmp_lr, alpha=1.0 - beta2)

        h = (v.sqrt() / bias_correction2_sq).add_(eps)

        stepsize = (1 / bias_correction1) / h

        update = m * stepsize

        p.add_(update, alpha=-lr)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss. it is invoked
            with gradients enabled.
        :return: the value returned by `closure`, or None when no closure is given.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            # the name stored by `get_optimizer_groups` selects the update rule below.
            name = group['name']

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])
            bias_correction2_sq: float = math.sqrt(bias_correction2)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                # all optimizer state is kept in float32, so the gradient is promoted to match.
                grad = grad.to(torch.float32)

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                if any(block in name for block in self.embed_blocks):
                    self.step_embed(
                        p, grad, state, group['lr'], beta1, beta2, bias_correction1, bias_correction2_sq, group['eps']
                    )
                elif any(block in name for block in self.qk_blocks):
                    self.step_attn_proj(
                        p,
                        grad,
                        state,
                        group['parameter_per_head'],
                        group['lr'],
                        beta1,
                        beta2,
                        bias_correction1,
                        bias_correction2_sq,
                        group['eps'],
                    )
                elif 'attn.attn.weight' in name or 'attn.qkv.weight' in name:
                    self.step_attn(
                        p,
                        grad,
                        state,
                        group['n_head'],
                        group['q_per_kv'],
                        group['lr'],
                        beta1,
                        beta2,
                        bias_correction1,
                        bias_correction2_sq,
                        group['eps'],
                    )
                else:
                    self.step_lefts(
                        p,
                        grad,
                        state,
                        group['lr'],
                        beta1,
                        beta2,
                        bias_correction1,
                        bias_correction2_sq,
                        group['eps'],
                    )

        return loss

AdaMax

Bases: BaseOptimizer

A variant of Adam based on the infinity norm (from "Adam: A Method for Stochastic Optimization").

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamax.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class AdaMax(BaseOptimizer):
    r"""A variant of Adam based on the infinity norm (from "Adam: A Method for Stochastic Optimization").

    Two buffers are kept per parameter: `exp_avg`, a running average of the gradient, and `exp_inf`, a decayed
    running maximum of the absolute gradient that serves as the denominator of the update.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. `beta1` is the decay of the gradient running average, `beta2` the decay of the running
        maximum.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the absolute gradient to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        # extra keyword arguments (e.g. `adanorm`, `adam_debias`) become group options and may override the above.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaMax'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state buffers of every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                # already initialized, leave the existing buffers untouched.
                continue

            param_state['exp_avg'] = torch.zeros_like(param)
            param_state['exp_inf'] = torch.zeros_like(param)

            if group.get('adanorm'):
                param_state['exp_grad_adanorm'] = torch.zeros((1,), dtype=gradient.dtype, device=gradient.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss. it is invoked
            with gradients enabled.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                # first call for this group: allocate state, then start counting from 1.
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                exp_avg = param_state['exp_avg']
                exp_inf = param_state['exp_inf']

                param, grad, exp_avg, exp_inf = self.view_as_real(param, grad, exp_avg, exp_inf)

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=param_state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1)
                exp_avg.add_(s_grad, alpha=1.0 - beta1)

                # exp_inf <- max(beta2 * exp_inf, |grad| + eps), computed by stacking both candidates on a new
                # leading axis and reducing over it. the decay is applied in-place before stacking.
                decayed_inf = exp_inf.mul_(beta2).unsqueeze(0)
                grad_magnitude = grad.abs().add_(group['eps']).unsqueeze_(0)
                candidates = torch.cat((decayed_inf, grad_magnitude), dim=0)
                torch.max(candidates, dim=0, keepdim=False, out=(exp_inf, exp_inf.new().long()))

                param.addcdiv_(exp_avg, exp_inf, value=-step_size)

        return loss

AdamC

Bases: BaseOptimizer

Why Gradients Rapidly Increase Near the End of Training.

Set normalized=True for LayerNorm and BatchNorm layers.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamc.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
class AdamC(BaseOptimizer):
    r"""Why Gradients Rapidly Increase Near the End of Training.

    Set `normalized=True` for LayerNorm and BatchNorm layers. For such groups the weight decay step uses
    `lr ** 2 / max_lr` in place of `lr`, where `max_lr` is the learning rate given to the constructor.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate. it is also stored as `max_lr`.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize
        # reference learning rate used to rescale the weight decay of `normalized` groups.
        self.max_lr: float = lr

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
        }
        # extra keyword arguments become group options and may override the entries above.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamC'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state buffers of every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                # already initialized, leave the existing buffers untouched.
                continue

            param_state['exp_avg'] = torch.zeros_like(param)
            param_state['exp_avg_sq'] = torch.zeros_like(param)

            if group['ams_bound']:
                param_state['max_exp_avg_sq'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss. it is invoked
            with gradients enabled.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                # first call for this group: allocate state, then start counting from 1.
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # normalized groups decay with lr^2 / max_lr, all other groups with the plain learning rate.
            wd_step_size: float
            if group.get('normalized'):
                wd_step_size = (group['lr'] ** 2) / self.max_lr
            else:
                wd_step_size = group['lr']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                exp_avg = param_state['exp_avg']
                exp_avg_sq = param_state['exp_avg_sq']

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=wd_step_size,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1)
                exp_avg.add_(grad, alpha=1.0 - beta1)

                exp_avg_sq.mul_(beta2)
                exp_avg_sq.addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=param_state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                corrected_avg = exp_avg / bias_correction1
                param.addcdiv_(corrected_avg, de_nom, value=-group['lr'])

        return loss

AdamG

Bases: BaseOptimizer

Towards Stability of Parameter-free Optimization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.95, 0.999, 0.95)
p float

float. p for a numerator function s(x) = p * x^q.

0.2
q float

float. q for a numerator function s(x) = p * x^q.

0.24
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamg.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class AdamG(BaseOptimizer):
    r"""Towards Stability of Parameter-free Optimization.

    Three buffers are kept per parameter: `v`, a running average of the squared gradient; `r`, a running average of
    the numerator function `s(v)`; and `m`, a running average of `r * grad`. The step size is
    `min(lr, 1 / sqrt(step))`.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate. it acts as an upper bound on the step size.
    :param betas: BETAS. `(beta1, beta2, beta3)`, the decay rates of `m`, `v` and `r` respectively.
    :param p: float. p for a numerator function `s(x) = p * x^q`.
    :param q: float. q for a numerator function `s(x) = p * x^q`.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.95, 0.999, 0.95),
        p: float = 0.2,
        q: float = 0.24,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_positive(p, 'p')
        self.validate_positive(q, 'q')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.p = p
        self.q = q
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        # extra keyword arguments become group options and may override the entries above.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamG'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state buffers of every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                # already initialized, leave the existing buffers untouched.
                continue

            for key in ('m', 'v', 'r'):
                param_state[key] = torch.zeros_like(param)

    def s(self, p: torch.Tensor) -> torch.Tensor:
        r"""Numerator function s(x) = p * x^q.

        :param p: torch.Tensor. input `x` of the function. note that it is unrelated to the coefficient `self.p`.
        :return: a new tensor holding `self.p * x ** self.q`. the input is not modified.
        """
        powered = p.pow(self.q)
        powered.mul_(self.p)
        return powered

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss. it is invoked
            with gradients enabled.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                # first call for this group: allocate state, then start counting from 1.
                self.init_group(group)
                group['step'] = 1

            beta1, beta2, beta3 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            # the learning rate only caps the step size, which otherwise decays as 1 / sqrt(step).
            step_size: float = min(group['lr'], 1.0 / math.sqrt(group['step']))

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                m = param_state['m']
                v = param_state['v']
                r = param_state['r']

                param, grad, m, v, r = self.view_as_real(param, grad, m, v, r)

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # order matters: r depends on the refreshed v, and m depends on the refreshed r.
                v.mul_(beta2)
                v.addcmul_(grad, grad, value=1.0 - beta2)

                r.mul_(beta3)
                r.add_(self.s(v), alpha=1.0 - beta3)

                m.mul_(beta1)
                m.addcmul_(r, grad, value=1.0 - beta1)

                numerator = m / bias_correction1
                denominator = (v / bias_correction2).sqrt_().add_(group['eps'])
                update = numerator / denominator

                param.add_(update, alpha=-step_size)

        return loss

s(p)

Numerator function s(x) = p * x^q.

Source code in pytorch_optimizer/optimizer/adamg.py
81
82
83
def s(self, p: torch.Tensor) -> torch.Tensor:
    r"""Numerator function s(x) = p * x^q.

    :param p: torch.Tensor. input `x` of the function. note that it is unrelated to the coefficient `self.p`.
    :return: a new tensor holding `self.p * x ** self.q`. the input is not modified.
    """
    powered = p.pow(self.q)
    powered.mul_(self.p)
    return powered

AdaMod

Bases: BaseOptimizer

An Adaptive and Momental Bound Method for Stochastic Learning.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. beta3 is for smoothing coefficient for adaptive learning rates.

(0.9, 0.99, 0.9999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamod.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
class AdaMod(BaseOptimizer):
    r"""An Adaptive and Momental Bound Method for Stochastic Learning.

    Keeps, per parameter, running averages of the gradient, of the squared gradient and of the element-wise
    learning rate; the element-wise learning rate actually applied is bounded by its own running average.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
        beta3 is the smoothing coefficient for the adaptive learning rates.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.99, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        # Validate hyper-parameters before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        # Extra keyword arguments are forwarded as additional per-group options.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaMod'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the zero-initialized state buffers for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if param_state:
                # already initialized
                continue

            for buffer_name in ('exp_avg', 'exp_avg_sq', 'exp_avg_lr'):
                param_state[buffer_name] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                # first step of this group: allocate the state buffers
                self.init_group(group)
                group['step'] = 1

            beta1, beta2, beta3 = group['betas']
            current_step = group['step']

            bias_correction1: float = self.debias(beta1, current_step)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, current_step))

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'] * bias_correction2_sq,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                exp_avg_lr = state['exp_avg_lr']

                p, grad, exp_avg, exp_avg_sq, exp_avg_lr = self.view_as_real(p, grad, exp_avg, exp_avg_sq, exp_avg_lr)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # first and second moment running averages
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                # element-wise learning rate: step_size / (sqrt(v) + eps)
                bounded_lr = torch.full_like(de_nom, fill_value=step_size).div_(de_nom)

                # running average of the element-wise learning rates, used as an upper bound
                exp_avg_lr.mul_(beta3).add_(bounded_lr, alpha=1.0 - beta3)

                torch.min(bounded_lr, exp_avg_lr, out=bounded_lr)
                bounded_lr.mul_(exp_avg)

                p.sub_(bounded_lr)

        return loss

AdamP

Bases: BaseOptimizer

Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
delta float

float. threshold that determines whether a set of parameters is scale invariant or not.

0.1
wd_ratio float

float. relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters.

0.1
nesterov bool

bool. enables Nesterov momentum.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamp.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
class AdamP(BaseOptimizer):
    r"""Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param delta: float. threshold that determines whether a set of parameters is scale invariant or not.
    :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
        on scale-variant parameters.
    :param nesterov: bool. enables Nesterov momentum.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        delta: float = 0.1,
        wd_ratio: float = 0.1,
        nesterov: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(wd_ratio, 'wd_ratio', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # Extra keyword arguments are stored as additional per-group options; `step` and `init_group`
        # look up 'adanorm', 'adanorm_r', 'adam_debias', 'use_gc' and 'cautious' from the group.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'delta': delta,
            'wd_ratio': wd_ratio,
            'nesterov': nesterov,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamP'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state buffers for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: when a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                if group.get('use_gc'):
                    centralize_gradient(grad, gc_conv_only=False)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # Bias-corrected denominator sqrt(v) / sqrt(bias_correction2) + eps. eps is added to the
                # denominator itself (not to its reciprocal), so an element whose second moment is still zero
                # yields 0 / eps = 0 instead of 0 * inf = NaN.
                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                perturb = exp_avg.clone()

                if group.get('cautious'):
                    self.apply_cautious(perturb, grad)

                if group['nesterov']:
                    # Nesterov look-ahead: blend momentum and current gradient first, so that the whole
                    # numerator (not only the gradient term) is normalized by the denominator below.
                    perturb.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                perturb.div_(de_nom)

                # The projection (and the reduced weight-decay ratio it returns) is only attempted for
                # parameters with more than one dimension.
                wd_ratio: float = 1.0
                if len(p.shape) > 1:
                    perturb, wd_ratio = projection(
                        p,
                        grad,
                        perturb,
                        group['delta'],
                        group['wd_ratio'],
                        group['eps'],
                    )

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=wd_ratio,
                )

                p.add_(perturb, alpha=-step_size)

        return loss

AdamS

Bases: BaseOptimizer

Adam with stable weight decay.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0001
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adams.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class AdamS(BaseOptimizer):
    r"""Adam with stable weight decay.

    The weight decay is scaled by the inverse of the square root of the mean bias-corrected second moment taken
    over all parameters that received a gradient, so every step is done in two passes: the first one updates the
    moments and accumulates that global statistic, the second one applies the weight decay and the update.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 1e-4,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamS'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state buffers for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: when a parameter has a sparse gradient.
        :raises NoComplexParameterError: when a parameter is complex-valued.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :raises ZeroParameterSizeError: when no parameter has a gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        # Pass 1: update the moments and accumulate the sum of the bias-corrected second moments.
        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                param_size += p.numel()

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if group['ams_bound']:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    exp_avg_sq_hat = max_exp_avg_sq
                else:
                    exp_avg_sq_hat = exp_avg_sq

                exp_avg_sq_hat_sum += exp_avg_sq_hat.sum() / bias_correction2

        if param_size == 0:
            raise ZeroParameterSizeError()

        # Square root of the mean bias-corrected second moment over all updated parameters (+ default eps).
        exp_avg_sq_hat_mean: float = math.sqrt(exp_avg_sq_hat_sum / param_size) + self.defaults['eps']

        # Pass 2: stable weight decay followed by the Adam update.
        for group in self.param_groups:
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # NOTE(review): maximize_gradient was already applied to this gradient in pass 1; if it
                # negates in place, this second call restores the sign of p.grad - confirm this is intended.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=1.0 / exp_avg_sq_hat_mean,
                )

                exp_avg_sq_hat = state['max_exp_avg_sq'] if group['ams_bound'] else state['exp_avg_sq']

                # The bias correction must be out-of-place: `exp_avg_sq_hat` aliases the persistent state
                # buffer, and dividing it in place would corrupt the running second moment on every step.
                de_nom = (exp_avg_sq_hat / bias_correction2).sqrt_().add_(group['eps'])

                p.addcdiv_(state['exp_avg'], de_nom, value=-step_size)

        return loss

AdamWSN

Bases: BaseOptimizer

Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees.

.. code-block:: python

sn_params = [module.weight for module in model.modules() if isinstance(module, nn.Linear)]
sn_param_ids = [id(p) for p in sn_params]
regular_params = [p for p in model.parameters() if id(p) not in sn_param_ids]
param_groups = [{'params': regular_params, 'sn': False}, {'params': sn_params, 'sn': True}]
optimizer = AdamWSN(param_groups, lr=args.lr, weight_decay=args.weight_decay, subset_size=args.subset_size)

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
subset_size int

int. If you do not know what subset_size to set, a good rule of thumb is to set it as d/2 where d is the hidden dimension of your transformer model. For example, the hidden dimension is 4096 for Llama 7B and so a good subset_size could be 2048. You can leave the subset_size argument to its default value of -1 to use the recommended subset size as stated above.

-1
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/snsm.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class AdamWSN(BaseOptimizer):
    r"""Lean and Mean Adaptive Optimization via Subset-Norm and Subspace-Momentum with Convergence Guarantees.

    Parameter groups carrying a truthy ``'sn'`` entry keep one second-moment value per subset of
    ``subset_size`` gradient elements; all other groups keep an element-wise second moment.

    .. code-block:: python

        sn_params = [module.weight for module in model.modules() if isinstance(module, nn.Linear)]
        sn_param_ids = [id(p) for p in sn_params]
        regular_params = [p for p in model.parameters() if id(p) not in sn_param_ids]
        param_groups = [{'params': regular_params, 'sn': False}, {'params': sn_params, 'sn': True}]
        optimizer = AdamWSN(param_groups, lr=args.lr, weight_decay=args.weight_decay, subset_size=args.subset_size)

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param subset_size: int. If you do not know what subset_size to set, a good rule of thumb is to set it as d/2 where
        d is the hidden dimension of your transformer model. For example, the hidden dimension is 4096 for Llama 7B and
        so a good subset_size could be 2048. You can leave the subset_size argument to its default value of -1 to use
        the recommended subset size as stated above.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        subset_size: int = -1,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        # Validate hyper-parameters before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'subset_size': subset_size,
            'eps': eps,
        }
        # Extra keyword arguments are forwarded as additional per-group options.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamWSN'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state buffers for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        """
        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            state = self.state[param]
            if state:
                # already initialized
                continue

            state['exp_avg'] = torch.zeros_like(grad)

            if not group.get('sn'):
                # element-wise second moment
                state['exp_avg_sq'] = torch.zeros_like(grad)
                continue

            numel: int = grad.numel()

            if 'subset_size' not in state:
                requested: int = group['subset_size']
                if requested > 0:
                    target = requested
                else:
                    # non-positive value: derive the target from sqrt(numel) scaled by 1 / |subset_size|
                    target = int(math.sqrt(numel) / abs(int(requested)))

                state['subset_size'] = closest_smaller_divisor_of_n_to_k(numel, target)

            chunk: int = state['subset_size']

            # one second-moment entry per subset: shape (numel // subset_size, 1)
            grad_2d = grad.view(numel // chunk, chunk)
            subset_sq_sum = grad_2d.pow(2).sum(dim=1, keepdim=True)
            state['exp_avg_sq'] = torch.zeros_like(subset_sq_sum)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                # first step of this group: allocate the state buffers
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']
            current_step = group['step']

            bias_correction1: float = self.debias(beta1, current_step)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, current_step))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            use_subset_norm = group.get('sn')

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                numel = grad.numel()

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                if use_subset_norm:
                    chunk = state['subset_size']
                    # sum of squared gradients inside each subset
                    sq_update = grad.view(numel // chunk, chunk).pow(2).sum(dim=1, keepdim=True)
                else:
                    sq_update = grad.pow(2)

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).add_(sq_update, alpha=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if use_subset_norm:
                    chunk = state['subset_size']
                    # the per-subset denominator broadcasts over the columns of each subset row
                    momentum_2d = exp_avg.view(numel // chunk, chunk)
                    normalized = (momentum_2d / de_nom).reshape(p.shape)
                    p.add_(normalized, alpha=-step_size)
                else:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)

                # weight decay is applied after the parameter update
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

        return loss

Adan

Bases: BaseOptimizer

Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.98, 0.92, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. decoupled weight decay.

False
max_grad_norm float

float. max gradient norm to clip.

0.0
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adan.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class Adan(BaseOptimizer):
    r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient (beta1), the gradient
        difference (beta2) and the squared gradient term (beta3).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. decoupled weight decay.
    :param max_grad_norm: float. max gradient norm to clip. 0.0 disables the clipping.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.98, 0.92, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        max_grad_norm: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(max_grad_norm, 'max_grad_norm')
        self.validate_non_negative(eps, 'eps')

        self.max_grad_norm = max_grad_norm
        self.maximize = maximize

        # extra keyword arguments (e.g. 'use_gc', 'adanorm', 'adanorm_r', read in ``step``) become group options.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'max_grad_norm': max_grad_norm,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adan'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the optimizer state of every parameter in ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :param kwargs: expected to carry ``clip_global_grad_norm``, the factor the first gradient is scaled by.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        clip_global_grad_norm = kwargs.get('clip_global_grad_norm')

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_diff'] = torch.zeros_like(p)
                # negated, clipped first gradient, so that adding the first clipped gradient in ``step`` cancels it.
                # NOTE(review): this is seeded from the gradient as it is *before* ``step`` applies
                # ``maximize_gradient`` / gradient centralization; if those change the gradient the first
                # difference presumably does not cancel - confirm against those helpers.
                state['previous_grad'] = grad.clone().mul_(-clip_global_grad_norm)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
        r"""Return the factor the gradients are multiplied by for global gradient-norm clipping.

        Returns ``1.0`` when ``max_grad_norm`` is 0.0, otherwise ``max_grad_norm / (norm + eps)`` clamped to at
        most 1.0.
        """
        if self.defaults['max_grad_norm'] == 0.0:
            return 1.0

        # module-level helper of the same name; the square root is taken here, so it presumably returns the
        # squared norm - TODO confirm.
        global_grad_norm = get_global_gradient_norm(self.param_groups)
        global_grad_norm.sqrt_().add_(self.defaults['eps'])

        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        clip_global_grad_norm = self.get_global_gradient_norm()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group, clip_global_grad_norm=clip_global_grad_norm)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2, beta3 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])
            bias_correction3_sq: float = math.sqrt(self.debias(beta3, group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']
                # holds the negated previous gradient; becomes the gradient difference once ``grad`` is added.
                grad_diff = state['previous_grad']

                p, grad, exp_avg, exp_avg_sq, exp_avg_diff, grad_diff = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, exp_avg_diff, grad_diff
                )

                grad.mul_(clip_global_grad_norm)

                if group.get('use_gc'):
                    centralize_gradient(grad, gc_conv_only=False)

                # grad_diff = g_t - g_{t-1}
                grad_diff.add_(grad)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)

                # grad_diff = g_t + beta2 * (g_t - g_{t-1}), whose square feeds the second moment.
                grad_diff.mul_(beta2).add_(grad)
                exp_avg_sq.mul_(beta3).addcmul_(grad_diff, grad_diff, value=1.0 - beta3)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction3_sq).add_(group['eps'])

                if group['weight_decouple']:
                    p.mul_(1.0 - group['lr'] * group['weight_decay'])

                p.addcdiv_(exp_avg, de_nom, value=-group['lr'] / bias_correction1)
                p.addcdiv_(exp_avg_diff, de_nom, value=-group['lr'] * beta2 / bias_correction2)

                if not group['weight_decouple']:
                    p.div_(1.0 + group['lr'] * group['weight_decay'])

                # Remember -g_t for the next step. The stored copy is negated instead of ``grad`` itself, so the
                # sign of the gradient tensor seen by the caller is not flipped as a side effect of ``step``.
                previous_grad = state['previous_grad']
                previous_grad.copy_(torch.view_as_complex(grad) if torch.is_complex(previous_grad) else grad).neg_()

        return loss

AdaNorm

Bases: BaseOptimizer

AdaNorm: Adaptive Gradient Norm Correction based Optimizer for CNNs.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and of the squared gradient.

(0.9, 0.99)
r float

float. EMA factor. between 0.9 ~ 0.99 is preferred.

0.95
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adanorm.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class AdaNorm(BaseOptimizer):
    r"""AdaNorm: Adaptive Gradient Norm Correction based Optimizer for CNNs.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the squared
        gradient.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.99),
        r: float = 0.95,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'r': r,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
        }
        # any extra keyword argument (e.g. 'adam_debias', read in ``step``) is stored as a group option.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaNorm'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state tensors of every parameter in ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            state = self.state[param]
            if len(state) != 0:
                # already initialized; keep the existing buffers.
                continue

            state['exp_avg'] = torch.zeros_like(param)
            state['exp_avg_var'] = torch.zeros_like(param)
            state['exp_grad_norm'] = torch.zeros((1,), dtype=param.dtype, device=param.device)

            if group['ams_bound']:
                state['max_exp_avg_var'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                # first call for this group: build the state, then start counting steps.
                self.init_group(group)
                group['step'] = 1

            step_count = group['step']
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, step_count)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, step_count))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]

                first_moment = state['exp_avg']
                second_moment = state['exp_avg_var']

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # the first moment is fed with the norm-corrected gradient, the second moment with the plain one.
                corrected_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=True,
                    exp_grad_norm=state['exp_grad_norm'],
                    r=group['r'],
                )

                first_moment.mul_(beta1).add_(corrected_grad, alpha=1.0 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=second_moment,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                param.addcdiv_(first_moment, de_nom, value=-step_size)

        return loss

AdaPNM

Bases: BaseOptimizer

Adam + Positive-Negative Momentum Optimizers.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients for the positive-negative momentum buffers (beta1, applied as beta1 ** 2), the running average of the squared gradient (beta2), and the positive-negative mixing weight (beta3).

(0.9, 0.999, 1.0)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. use weight_decouple.

True
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the ams_bound variant.

True
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adapnm.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class AdaPNM(BaseOptimizer):
    r"""Adam + Positive-Negative Momentum Optimizers.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients for the momentum buffers (beta1, applied as beta1 ** 2), the running average
        of the squared gradient (beta2) and the positive-negative mixing weight (beta3).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use weight_decouple.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the ams_bound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
        }
        # extra keyword arguments (e.g. 'adanorm', 'adanorm_r', 'adam_debias') are kept as group options.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaPNM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state tensors of every parameter in ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[param]
            if len(state) != 0:
                # already initialized; keep the existing buffers.
                continue

            state['exp_avg'] = torch.zeros_like(param)
            state['exp_avg_sq'] = torch.zeros_like(param)
            state['neg_exp_avg'] = torch.zeros_like(param)

            if group['ams_bound']:
                state['max_exp_avg_sq'] = torch.zeros_like(param)

            if group.get('adanorm'):
                state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                # first call for this group: build the state, then start counting steps.
                self.init_group(group)
                group['step'] = 1

            step_count = group['step']
            beta1, beta2, beta3 = group['betas']

            beta1_p2: float = beta1 ** 2  # fmt: skip
            noise_norm: float = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)  # fmt: skip

            bias_correction1: float = self.debias(beta1, step_count)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, step_count))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'],
                bias_correction1=bias_correction1,
            )

            # the two momentum buffers swap roles on every step: odd steps update 'exp_avg', even steps update
            # 'neg_exp_avg'; the other buffer is the one subtracted in the positive-negative mix.
            if step_count % 2 == 1:
                updated_key, other_key = 'exp_avg', 'neg_exp_avg'
            else:
                updated_key, other_key = 'neg_exp_avg', 'exp_avg'

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg_sq = state['exp_avg_sq']
                exp_avg = state[updated_key]
                neg_exp_avg = state[other_key]

                param, grad, exp_avg, neg_exp_avg, exp_avg_sq = self.view_as_real(
                    param, grad, exp_avg, neg_exp_avg, exp_avg_sq
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1_p2).add_(s_grad, alpha=1.0 - beta1_p2)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                # ((1 + beta3) * updated - beta3 * other) / noise_norm, built out of place so that the stored
                # momentum buffer is left untouched.
                pn_momentum = exp_avg.mul(1.0 + beta3)
                pn_momentum.add_(neg_exp_avg, alpha=-beta3)
                pn_momentum.mul_(1.0 / noise_norm)

                param.addcdiv_(pn_momentum, de_nom, value=-step_size)

        return loss

AdaShift

Bases: BaseOptimizer

Decorrelation and Convergence of Adaptive Learning Rate Methods.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. beta1 sets the exponential weights of the kept gradients used for the first moment estimation; beta2 is the coefficient of the running average of the (reduced) squared gradient.

(0.9, 0.999)
keep_num int

int. number of gradients used to compute first moment estimation.

10
reduce_func Optional[Callable]

Optional[Callable]. function applied to squared gradients to further reduce the correlation. If None, no function is applied.

max
eps float

float. term added to the denominator to improve numerical stability.

1e-10
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adashift.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class AdaShift(BaseOptimizer):
    r"""Decorrelation and Convergence of Adaptive Learning Rate Methods.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. beta1 sets the exponential weights of the kept gradients used for the first moment
        estimation; beta2 is the coefficient of the running average of the (reduced) squared gradient.
    :param keep_num: int. number of gradients used to compute first moment estimation.
    :param reduce_func: Optional[Callable]. function applied to squared gradients to further reduce the correlation.
        If None, no function is applied.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        keep_num: int = 10,
        reduce_func: Optional[Callable] = torch.max,
        eps: float = 1e-10,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_positive(keep_num, 'keep_num')
        self.validate_non_negative(eps, 'eps')

        # ``None`` means "no reduction": fall back to the identity function.
        self.reduce_func: Callable = reduce_func if reduce_func is not None else lambda x: x
        self.maximize = maximize

        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'keep_num': keep_num, 'eps': eps, **kwargs}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaShift'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state of every parameter in ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # bounded queue of the most recent gradients, seeded with a copy of the current one.
                state['grad_queue'] = deque([grad.clone()], maxlen=group['keep_num'])
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        Parameters are only updated once the gradient queue holds ``keep_num`` gradients.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']
            keep_num: int = group['keep_num']

            exp_weight_sum: float = sum(beta1 ** i for i in range(keep_num))  # fmt: skip
            first_grad_weight: float = beta1 ** (keep_num - 1) / exp_weight_sum
            last_grad_weight: float = 1.0 / exp_weight_sum

            # The queue is seeded with the first gradient in ``init_group`` and that gradient is appended once more
            # in the same ``step`` call, so the queue is full - and the moments start being updated - from step
            # ``max(keep_num - 1, 1)`` on. The bias correction must use the number of second-moment updates made so
            # far; ``step - keep_num`` is -1 and then 0 on the first two updating steps, which (assuming ``debias``
            # is ``1 - beta ** step``) gives the square root of a negative value and then a division by zero.
            num_updates: int = group['step'] - max(keep_num - 1, 1) + 1
            # before the queue is full the value is unused; clamp so the exponent is never non-positive.
            bias_correction: float = self.debias(beta2, max(num_updates, 1))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_queue = state['grad_queue']
                grad_queue.append(grad.clone())

                if len(grad_queue) != keep_num:
                    continue

                # oldest kept gradient: the temporally shifted gradient that drives the second moment.
                offset_grad = grad_queue[0]

                exp_avg = state['exp_avg']
                exp_avg.sub_(offset_grad, alpha=first_grad_weight).mul_(beta1).add_(grad, alpha=last_grad_weight)

                # squared in place: the queue is full here, so this entry is dropped by the next append.
                reduced_grad_sq = self.reduce_func(offset_grad.pow_(2))

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)

                update = exp_avg.clone()
                if group.get('cautious'):
                    self.apply_cautious(update, grad)

                update.div_(exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps']))

                p.add_(update, alpha=-group['lr'])

        return loss

AdaSmooth

Bases: BaseOptimizer

An Adaptive Learning Rate Method based on Effective Ratio.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. bounds of the adaptive decay constant, which is interpolated between 1 - beta2 and 1 - beta1 according to the effective ratio.

(0.5, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adasmooth.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class AdaSmooth(BaseOptimizer):
    r"""An Adaptive Learning Rate Method based on Effective Ratio.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. bounds of the adaptive decay constant, which is interpolated between 1 - beta2 and
        1 - beta1 according to the effective ratio.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.5, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaSmooth'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state tensors of every parameter in ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # reference point for the parameter displacement; starts at zero, not at the current value.
                state['prev_param'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)  # running sum of signed displacements
                state['n'] = torch.zeros_like(p)  # running sum of absolute displacements
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                s, n, prev_param, exp_avg_sq = state['s'], state['n'], state['prev_param'], state['exp_avg_sq']

                p, grad, s, n, prev_param, exp_avg_sq = self.view_as_real(p, grad, s, n, prev_param, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p_diff = p - prev_param

                s.add_(p_diff)
                n.add_(p_diff.abs())

                # ``n`` accumulates absolute values, so its sum is zero only when every accumulated displacement is
                # zero (e.g. a zero-initialized parameter on the first step, since ``prev_param`` starts at zero).
                # Clamp the denominator so that case gives an effective ratio of 0 instead of 0 / 0 = NaN, which
                # would otherwise propagate into ``exp_avg_sq`` and the parameter itself.
                total_abs_diff = n.sum()
                total_abs_diff.clamp_(min=torch.finfo(total_abs_diff.dtype).tiny)

                c = s.sum().abs_().div_(total_abs_diff)  # e_t
                # map the ratio from [0, 1] onto [1 - beta2, 1 - beta1].
                c.mul_(beta2 - beta1).add_(1.0 - beta2)

                c_p2 = c.pow(2)

                exp_avg_sq.mul_(1.0 - c_p2).addcmul_(grad, grad, value=c_p2)

                step_size = torch.full_like(exp_avg_sq, fill_value=group['lr'])
                step_size.div_((exp_avg_sq + group['eps']).sqrt()).mul_(grad)

                p.add_(-step_size)

                # remember the updated parameter as the reference point of the next displacement.
                state['prev_param'].copy_(torch.view_as_complex(p) if torch.is_complex(state['prev_param']) else p)

        return loss

AdEMAMix

Bases: BaseOptimizer

Better, Faster, Older.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999, 0.9999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
alpha float

float. usually between 4 and 10 would work well.

5.0
t_alpha_beta3 Optional[float]

Optional[float]. total number of iterations is preferred when needed.

None
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ademamix.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class AdEMAMix(BaseOptimizer):
    r"""Better, Faster, Older.

    Keeps three running statistics per parameter: a fast EMA of the gradient (`exp_avg`), an EMA of the squared
    gradient (`exp_avg_sq`) and a slow EMA of the gradient (`exp_avg_slow`). The update mixes the fast EMA with
    `alpha` times the slow EMA and divides by the bias-corrected root of `exp_avg_sq`.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. (beta1, beta2, beta3) decay rates of the fast gradient EMA, the squared-gradient EMA and
        the slow gradient EMA.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param alpha: float. usually between 4 and 10 would work well.
    :param t_alpha_beta3: Optional[float]. number of steps over which alpha and beta3 are warmed up. total number of
        iterations is preferred when needed. None disables the schedule.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: extra entries merged into the group defaults. `cautious` and `stable_adamw` are read by `step()`
        when present.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        alpha: float = 5.0,
        t_alpha_beta3: Optional[float] = None,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(alpha, 'alpha')
        self.validate_non_negative(t_alpha_beta3, 't_alpha_beta3')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'alpha': alpha,
            't_alpha_beta3': t_alpha_beta3,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdEMAMix'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create zero-initialized state buffers for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_slow'] = torch.zeros_like(p)

    @staticmethod
    def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float:
        r"""Return alpha for the given step: a linear ramp `step * alpha / t_alpha_beta3` capped at `alpha`.

        :param t_alpha_beta3: Optional[float]. warmup length. None returns `alpha` unchanged.
        :param step: int. current step.
        :param alpha: float. final alpha.
        """
        return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha)

    @staticmethod
    def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float:
        r"""Return beta3 for the given step, interpolated between beta1 and beta3 in log space and capped at beta3.

        :param t_alpha_beta3: Optional[float]. warmup length. None returns `beta3` unchanged.
        :param step: int. current step.
        :param beta1: float. value the schedule starts from.
        :param beta3: float. final value and upper bound of the schedule.
        """
        if t_alpha_beta3 is None:
            return beta3

        log_beta1, log_beta3 = math.log(beta1), math.log(beta3)

        return min(
            math.exp(
                log_beta1 * log_beta3 / ((1.0 - step / t_alpha_beta3) * log_beta3 + (step / t_alpha_beta3) * log_beta1)
            ),
            beta3,
        )

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2, beta3 = group['betas']

            # bias-correction terms come from the base-class `debias` helper.
            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # group-level step size; it must stay constant across the parameters of the group.
            # note that it scales the whole mixed update, i.e. both the fast and the slow EMA term.
            step_size: float = group['lr'] / bias_correction1

            alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha'])
            beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq, exp_avg_slow = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_slow']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1.0 - beta3_t)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                # clone so that the cautious mask and the slow-EMA mixing do not touch the stored fast EMA.
                update = exp_avg.clone()
                if group.get('cautious'):
                    self.apply_cautious(update, grad)

                # the StableAdamW scaling is specific to this parameter. derive a per-parameter step size instead
                # of dividing `step_size` in place, which would compound the divisor over all later parameters.
                param_step_size = step_size
                if group.get('stable_adamw'):
                    param_step_size = step_size / self.get_stable_adamw_rms(grad, exp_avg_sq)

                update.add_(exp_avg_slow, alpha=alpha_t).div_(de_nom)

                p.add_(update, alpha=-param_step_size)

        return loss

SimplifiedAdEMAMix

Bases: BaseOptimizer

Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0001
betas BETAS

BETAS. coefficients (beta1, beta2) used for computing running averages of the gradient and of the squared gradient.

(0.99, 0.95)
alpha float

float. coefficient for mixing the current gradient and EMA.

0.0
beta1_warmup Optional[int]

Optional[int]. number of warmup steps used to increase beta1.

None
min_beta1 float

float. minimum value of beta1 to start from.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ademamix.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
class SimplifiedAdEMAMix(BaseOptimizer):
    r"""Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants.

    Tracks one EMA of the gradient and one EMA of the squared gradient per parameter and applies an update built
    from `alpha * grad + exp_avg`, normalized by the root of the squared-gradient EMA.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. (beta1, beta2) decay rates of the gradient EMA and of the squared-gradient EMA.
    :param alpha: float. coefficient for mixing the current gradient and EMA.
    :param beta1_warmup: Optional[int]. number of warmup steps used to increase beta1.
    :param min_beta1: float. minimum value of beta1 to start from.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.99, 0.95),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        alpha: float = 0.0,
        beta1_warmup: Optional[int] = None,
        min_beta1: float = 0.9,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)

        # checked in the same order as the individual calls would be.
        non_negatives = (
            (alpha, 'alpha'),
            (min_beta1, 'min_beta1'),
            (weight_decay, 'weight_decay'),
            (eps, 'eps'),
        )
        for value, name in non_negatives:
            self.validate_non_negative(value, name)

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'beta1_warmup': beta1_warmup,
            'min_beta1': min_beta1,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        # extra keyword arguments become (or override) group defaults.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SimplifiedAdEMAMix'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the per-parameter state for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            grad = p.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]
            if len(state) != 0:
                # state already exists, leave it untouched.
                continue

            state['exp_avg'] = torch.zeros_like(p)
            state['exp_avg_sq'] = torch.zeros_like(p)
            state['num_sum'] = 0.0
            state['den_sum'] = 0.0

    @staticmethod
    def linear_hl_warmup_scheduler(step: int, beta_end: float, beta_start: float = 0.0, warmup: int = 1) -> float:
        r"""Return beta for `step`, warming up from `beta_start` to `beta_end` over `warmup` steps.

        During warmup, the betas are mapped through `log(0.5) / log(beta) - 1`, interpolated linearly in that
        space, and mapped back. From `step >= warmup` on, `beta_end` is returned.

        :param step: int. current step.
        :param beta_end: float. beta reached at the end of the warmup.
        :param beta_start: float. beta at the start of the warmup.
        :param warmup: int. number of warmup steps.
        """
        if not step < warmup:
            return beta_end

        def to_half_life(beta: float, eps: float = 1e-8) -> float:
            return math.log(0.5) / math.log(beta + eps) - 1.0

        def from_half_life(t: float) -> float:
            return math.pow(0.5, 1.0 / (t + 1))

        a: float = step / float(warmup)
        mixed: float = (1.0 - a) * to_half_life(beta_start) + a * to_half_life(beta_end)

        return from_half_life(mixed)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            warmup_steps = group['beta1_warmup']
            if warmup_steps:
                beta1 = self.linear_hl_warmup_scheduler(
                    group['step'], beta_end=beta1, beta_start=group['min_beta1'], warmup=warmup_steps
                )

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # running normalizers. `num_sum` is only stored, `den_sum` feeds the denominator below.
                state['num_sum'] = beta1 * state['num_sum'] + 1.0
                state['den_sum'] = beta2 * state['den_sum'] + (1.0 - beta2)

                sqrt_den_sum: float = math.sqrt(state['den_sum'])

                de_nom = exp_avg_sq.sqrt()
                de_nom.add_(sqrt_den_sum * group['eps'])

                update = group['alpha'] * grad + exp_avg
                update.div_(de_nom)
                update.div_(sqrt_den_sum)

                p.add_(update, alpha=-group['lr'])

        return loss

ADOPT

Bases: BaseOptimizer

Modified Adam Can Converge with Any β2 with the Optimal Rate.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients (beta1, beta2) used for computing running averages of the normalized gradient and of the squared gradient.

(0.9, 0.9999)
clip_lambda Callable[[float], float]

Callable[[float], float]. function to clip gradient. default is step ** 0.25

lambda step: pow(step, 0.25)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adopt.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class ADOPT(BaseOptimizer):
    r"""Modified Adam Can Converge with Any β2 with the Optimal Rate.

    The gradient is normalized by the root of the squared-gradient EMA of the *previous* step (optionally clipped),
    accumulated into `exp_avg`, applied to the parameter, and only then is the squared-gradient EMA updated.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. (beta1, beta2) decay rates of the normalized-gradient EMA and of the squared-gradient EMA.
    :param clip_lambda: Callable[[float], float]. function to clip gradient. default is `step ** 0.25`
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: extra entries merged into the group defaults. `cautious` and `stable_adamw` are read by `step()`
        when present.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.9999),
        clip_lambda: Callable[[float], float] = lambda step: math.pow(step, 0.25),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.clip_lambda = clip_lambda
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ADOPT'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create zero-initialized state buffers for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']
            # group-level learning rate; it must stay constant across the parameters of the group.
            lr: float = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # base-class helper; presumably returns real views of complex tensors so that the in-place
                # updates below still write through to the parameter and its state.
                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['step'] == 1:
                    # the first step only seeds the second-moment estimate; no gradient update is applied.
                    exp_avg_sq.addcmul_(grad, grad.conj())
                    continue

                de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])

                normed_grad = grad.div(de_nom)
                if self.clip_lambda is not None:
                    clip = self.clip_lambda(group['step'])
                    normed_grad.clamp_(-clip, clip)

                exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

                if group.get('cautious'):
                    # work on a copy so that the cautious mask does not alter the stored EMA.
                    update = exp_avg.clone()
                    self.apply_cautious(update, normed_grad)
                else:
                    update = exp_avg

                # the StableAdamW scaling is specific to this parameter. derive a per-parameter learning rate
                # instead of dividing `lr` in place, which would compound the divisor over all later parameters
                # (and leak into their weight decay).
                param_lr = lr
                if group.get('stable_adamw'):
                    param_lr = lr / self.get_stable_adamw_rms(grad, exp_avg_sq)

                p.add_(update, alpha=-param_lr)

                # second moment is updated after the parameter step, so the next step normalizes with it.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)

        return loss

agc

agc(p, grad, agc_eps=0.001, agc_clip_val=0.01, eps=1e-06)

Clip gradient values in excess of the unit wise norm.

Parameters:

Name Type Description Default
p Tensor

torch.Tensor. parameter.

required
grad Tensor

torch.Tensor. gradient.

required
agc_eps float

float. agc epsilon to clip the norm of parameter.

0.001
agc_clip_val float

float. norm clip.

0.01
eps float

float. simple stop from div by zero and no relation to standard optimizer eps.

1e-06
Source code in pytorch_optimizer/optimizer/agc.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
def agc(
    p: torch.Tensor, grad: torch.Tensor, agc_eps: float = 1e-3, agc_clip_val: float = 1e-2, eps: float = 1e-6
) -> torch.Tensor:
    r"""Clip gradient values in excess of the unit wise norm.

    The allowed norm is `agc_clip_val` times the unit norm of `p` (floored at `agc_eps`). Wherever the unit norm of
    `grad` exceeds it, the gradient is rescaled by `allowed / actual`; elsewhere it is returned unchanged.

    :param p: torch.Tensor. parameter.
    :param grad: torch.Tensor. gradient.
    :param agc_eps: float. agc epsilon to clip the norm of parameter.
    :param agc_clip_val: float. norm clip.
    :param eps: float. simple stop from div by zero and no relation to standard optimizer eps.
    :return: torch.Tensor. the clipped gradient.
    """
    # largest gradient norm that is tolerated, derived from the parameter norm.
    allowed_norm = unit_norm(p)
    allowed_norm.clamp_min_(agc_eps)
    allowed_norm.mul_(agc_clip_val)

    # actual gradient norm, floored so that the division below is safe.
    grad_norm = unit_norm(grad)
    grad_norm.clamp_min_(eps)

    rescaled = grad * (allowed_norm / grad_norm)
    too_large = grad_norm > allowed_norm

    return torch.where(too_large, rescaled, grad)

AggMo

Bases: BaseOptimizer

Aggregated Momentum: Stability Through Passive Damping.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. momentum (damping) coefficients; one momentum buffer is kept per coefficient and their updates are averaged.

(0.0, 0.9, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/aggmo.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
class AggMo(BaseOptimizer):
    r"""Aggregated Momentum: Stability Through Passive Damping.

    Keeps one momentum buffer per entry of `betas` for every parameter and applies each buffer with a step of
    `lr / len(betas)`.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. momentum coefficients, one momentum buffer is kept per coefficient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.0, 0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        # note: extra keyword arguments are accepted but not stored in the defaults.
        defaults: DEFAULTS = {}
        defaults['lr'] = lr
        defaults['betas'] = betas
        defaults['weight_decay'] = weight_decay
        defaults['weight_decouple'] = weight_decouple
        defaults['fixed_decay'] = fixed_decay

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AggMo'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the momentum buffers for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            grad = p.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]
            if len(state) != 0:
                # buffers already exist, leave them untouched.
                continue

            # buffers are keyed by the beta value itself.
            buffers = {}
            for beta in group['betas']:
                buffers[beta] = torch.zeros_like(p)

            state['momentum_buffer'] = buffers

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            betas = group['betas']

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                momentum_buffers = state['momentum_buffer']
                for beta in betas:
                    velocity = momentum_buffers[beta]
                    velocity.mul_(beta)
                    velocity.add_(grad)

                    # every buffer contributes an equal share of the learning rate.
                    p.add_(velocity, alpha=-group['lr'] / len(betas))

        return loss

Aida

Bases: BaseOptimizer

A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
k int

int. number of vector projected per iteration.

2
xi float

float. term used in vector projections to avoid division by zero.

1e-20
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
rectify bool

bool. perform the rectified update similar to RAdam.

False
n_sma_threshold int

int. number of SMA threshold (recommended is 5).

5
degenerated_to_sgd bool

bool. perform SGD update when variance of gradient is high.

True
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/aida.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class Aida(BaseOptimizer):
    r"""AdaBelief-style optimizer that suppresses the range of the adaptive step size (Aida).

    Before the second-moment estimate is updated, the gradient and its running average are projected onto each
    other ``k`` times; the squared difference of the two projected vectors is what feeds the variance estimate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. decay rates of the running averages of the gradient and of the squared residual.
    :param k: int. number of mutual projection rounds per iteration.
    :param xi: float. constant added to the projection denominators to avoid division by zero.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param n_sma_threshold: int. SMA threshold (the recommended value is 5).
    :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        k: int = 2,
        xi: float = 1e-20,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        # hyper-parameters are checked before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(k, 'k')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(xi, 'xi')
        self.validate_non_negative(eps, 'eps')

        # settings shared by every parameter group are kept on the instance.
        self.k, self.xi = k, xi
        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'eps': eps,
        }
        # extra keyword arguments (read below via `group.get(...)`) become per-group options.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Aida'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the optimizer state of every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            state = self.state[param]
            if len(state) != 0:
                # state already exists for this parameter.
                continue

            state['exp_avg'] = torch.zeros_like(param)
            state['exp_avg_var'] = torch.zeros_like(param)

            if group['ams_bound']:
                state['max_exp_avg_var'] = torch.zeros_like(param)

            if group.get('adanorm'):
                state['exp_grad_adanorm'] = torch.zeros((1,), dtype=gradient.dtype, device=gradient.device)

    def _projected_residual(self, grad: torch.Tensor, exp_avg: torch.Tensor) -> torch.Tensor:
        r"""Project copies of `grad` and `exp_avg` onto each other `self.k` times and return their difference.

        :param grad: torch.Tensor. gradient. not modified.
        :param exp_avg: torch.Tensor. running average of the gradient. not modified.
        :return: torch.Tensor. projected `exp_avg` minus projected `grad`.
        """
        g_proj = grad.detach().clone()
        m_proj = exp_avg.detach().clone()

        for _ in range(self.k):
            inner = (g_proj * m_proj).sum()

            # both coefficients are computed from the vectors of the previous round before either is rescaled.
            g_coef = inner / g_proj.pow(2).sum().add_(self.xi)
            m_coef = inner / m_proj.pow(2).sum().add_(self.xi)

            g_proj.mul_(g_coef)
            m_proj.mul_(m_coef)

        return m_proj - g_proj

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            step_count = group['step']
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, step_count)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, step_count))

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=step_count,
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg = state['exp_avg']
                exp_avg_var = state['exp_avg_var']

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

                # the variance estimate tracks the squared gap between the projected momentum and gradient.
                residual = self._projected_residual(grad, exp_avg)
                exp_avg_var.mul_(beta2).addcmul_(residual, residual, value=1.0 - beta2).add_(group['eps'])

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_var,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )

                if not group['rectify']:
                    de_nom.div_(bias_correction2_sq)
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif n_sma >= self.n_sma_threshold:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    # rectified mode below the SMA threshold: plain momentum step.
                    p.add_(exp_avg, alpha=-step_size)

        return loss

Alice

Bases: BaseOptimizer

Adaptive low-dimensional subspace estimation.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.02
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace. beta3=0 for Alice-0 optimizer.

(0.9, 0.9, 0.999)
alpha float

float. scaler.

0.3
alpha_c float

float. compensation scaler.

0.4
update_interval int

int. update interval.

200
rank int

int. rank.

256
gamma float

float. limiter threshold.

1.01
leading_basis int

int. leading basis.

40
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/racs.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
class Alice(BaseOptimizer):
    r"""Adaptive low-dimensional subspace estimation.

    Gradients are viewed as 2-D matrices and projected onto a basis `U` that is refreshed on the first step and
    then every `update_interval` steps. Adam-style moments are tracked inside the subspace, and a norm-limited
    compensation term is added for the part of the gradient outside of it.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients (beta1, beta2, beta3) of the running averages of the projected gradient, of
        its square and of the projected gradient covariance `Q`. beta3=0 for Alice-0 optimizer.
    :param alpha: float. scaler applied to the whole update.
    :param alpha_c: float. compensation scaler.
    :param update_interval: int. number of steps between two refreshes of the basis.
    :param rank: int. rank of the subspace.
    :param gamma: float. limiter threshold.
    :param leading_basis: int. number of leading eigenvectors kept when the basis is refreshed.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 0.02,
        betas: BETAS = (0.9, 0.9, 0.999),
        alpha: float = 0.3,
        alpha_c: float = 0.4,
        update_interval: int = 200,
        rank: int = 256,
        gamma: float = 1.01,
        leading_basis: int = 40,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_range(alpha_c, 'alpha_c', 0.0, 1.0)
        self.validate_positive(update_interval, 'update_interval')
        self.validate_positive(rank, 'rank')
        self.validate_positive(gamma, 'gamma')
        self.validate_positive(leading_basis, 'leading_basis')
        # the basis is made of `leading_basis` eigenvectors plus `rank - leading_basis` complement vectors.
        self.validate_non_negative(rank - leading_basis, 'rank - leading_basis')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'alpha_c': alpha_c,
            'update_interval': update_interval,
            'rank': rank,
            'gamma': gamma,
            'leading_basis': leading_basis,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Alice'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Do nothing; the per-parameter state is created lazily inside `step`."""

    @staticmethod
    def subspace_iteration(
        a: torch.Tensor, mat: torch.Tensor, num_steps: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Perform subspace iteration.

        :param a: torch.Tensor. square matrix whose leading subspace is estimated.
        :param mat: torch.Tensor. initial basis.
        :param num_steps: int. number of `qr(a @ u)` refinement rounds.
        :return: result of `torch.linalg.eigh` on `a` restricted to the refined basis.
        """
        u = mat
        for _ in range(num_steps):
            u, _ = torch.linalg.qr(a @ u)

        return torch.linalg.eigh(u.T @ a @ u)

    def switch(self, q: torch.Tensor, u_prev: torch.Tensor, rank: int, leading_basis: int) -> torch.Tensor:
        r"""Build a new basis from the leading eigenvectors plus vectors taken from their complement.

        :param q: torch.Tensor. square matrix used to estimate the leading directions.
        :param u_prev: torch.Tensor. previous basis, used as the starting point of the subspace iteration.
        :param rank: int. total number of basis vectors requested.
        :param leading_basis: int. number of leading eigenvectors to keep.
        :return: torch.Tensor. concatenation of the leading and the complement vectors, in the dtype of `q`.
        """
        # the decomposition runs in float32 and the result is cast back to the dtype of `q` on return.
        vals, vecs = self.subspace_iteration(q.to(torch.float32), u_prev.to(torch.float32), num_steps=1)

        leading_indices = torch.argsort(vals, descending=True)[:leading_basis]
        u_t1 = vecs[:, leading_indices]

        # orthogonalise the projector onto the complement of the leading vectors and keep its first columns.
        u_c, _ = torch.linalg.qr(torch.eye(q.shape[0], device=q.device) - u_t1 @ u_t1.T)
        u_t2 = u_c[:, :rank - leading_basis]  # fmt: skip

        return torch.cat([u_t1, u_t2], dim=1).to(q.dtype)

    @staticmethod
    def compensation(
        grad: torch.Tensor,
        u: torch.Tensor,
        p: torch.Tensor,
        phi: torch.Tensor,
        gamma: float,
        decay_rate: float,
        rank: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Compute the norm-limited compensation term for the gradient component outside of the subspace.

        :param grad: torch.Tensor. 2-D gradient.
        :param u: torch.Tensor. current basis.
        :param p: torch.Tensor. running average of the per-column residual statistic. updated in-place.
        :param phi: torch.Tensor. norm of the previous compensation term.
        :param gamma: float. limiter threshold; the norm may grow by at most this factor between two steps.
        :param decay_rate: float. decay rate of the running average `p`.
        :param rank: int. rank of the subspace.
        :return: tuple of the compensation term and its norm.
        """
        m, n = grad.shape

        sigma = u.T @ grad

        # column-wise squared norm of the gradient minus that of its projection, floored to stay positive.
        p.mul_(decay_rate).add_(grad.pow(2).sum(dim=0) - sigma.pow(2).sum(dim=0), alpha=1.0 - decay_rate).clamp_min_(
            1e-8
        )

        d = torch.zeros_like(grad)
        diag_len: int = min(m, n)
        d[torch.arange(diag_len), torch.arange(diag_len)] = 1.0 / p.sqrt()[:diag_len]

        # NOTE(review): `d` is non-zero only on its main diagonal and is applied element-wise, so the off-diagonal
        # entries of the residual are zeroed. confirm this is the intended scaling.
        c_t = math.sqrt(m - rank) * (grad - u @ sigma) * d if m >= rank else torch.zeros_like(grad)

        # no limiting on the first call (`phi` is still zero); afterwards cap the norm at `gamma * phi`.
        limiter = gamma / max(torch.norm(c_t) / phi, gamma) if phi.item() > 0 else torch.ones_like(phi)

        c_t.mul_(limiter)
        phi = torch.norm(c_t)

        return c_t, phi

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2, beta3 = group['betas']
            rank, leading_basis = group['rank'], group['leading_basis']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                state = self.state[p]

                # the algorithm works on matrices: 1-D gradients become a column, higher ranks are flattened.
                if grad.ndim < 2:
                    grad = grad.reshape(len(grad), 1)
                elif grad.ndim > 2:
                    grad = grad.reshape(len(grad), -1)

                if len(state) == 0:
                    m, n = grad.shape

                    state['U'] = torch.zeros((m, rank), dtype=p.dtype, device=p.device)
                    state['Q'] = torch.zeros((rank, rank), dtype=p.dtype, device=p.device)

                    state['m'] = torch.zeros((rank, n), dtype=p.dtype, device=p.device)
                    state['v'] = torch.zeros((rank, n), dtype=p.dtype, device=p.device)

                    state['p'] = torch.zeros((n,), dtype=p.dtype, device=p.device)
                    state['phi'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                q, u, m, v = state['Q'], state['U'], state['m'], state['v']

                if group['step'] == 1 or group['step'] % group['update_interval'] == 0:
                    q_t = beta3 * (u @ q @ u.T) + (1.0 - beta3) * (grad @ grad.T)
                    u = self.switch(q_t, u, rank, leading_basis)

                    # keep the refreshed basis; without this every step between two refreshes would project
                    # onto the all-zero initial `U`.
                    state['U'] = u

                sigma = u.T @ grad

                q.mul_(beta3).add_(sigma @ sigma.T, alpha=1.0 - beta3)
                m.mul_(beta1).add_(sigma, alpha=1.0 - beta1)
                v.mul_(beta2).add_(sigma.pow(2), alpha=1.0 - beta2)

                c_t, phi = self.compensation(grad, u, state['p'], state['phi'], group['gamma'], beta1, rank)

                # Adam-like direction inside the subspace, mapped back to the parameter space.
                update = u @ (m / v.sqrt())
                update.add_(c_t, alpha=group['alpha_c'])

                p.add_(update.view_as(p), alpha=-group['lr'] * group['alpha'])

                state['phi'] = phi

        return loss

subspace_iteration(a, mat, num_steps=1) staticmethod

Perform subspace iteration.

Source code in pytorch_optimizer/optimizer/racs.py
215
216
217
218
219
220
221
222
223
224
@staticmethod
def subspace_iteration(
    a: torch.Tensor, mat: torch.Tensor, num_steps: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Refine a basis with QR-orthogonalised power steps and eigendecompose `a` restricted to it.

    :param a: torch.Tensor. square matrix whose leading subspace is estimated.
    :param mat: torch.Tensor. initial basis.
    :param num_steps: int. number of `qr(a @ basis)` refinement rounds.
    :return: result of `torch.linalg.eigh` on the restricted matrix.
    """
    basis = mat
    for _round in range(num_steps):
        # only the orthonormal factor is kept; the triangular factor is not needed.
        basis, _r_factor = torch.linalg.qr(a @ basis)

    restricted = basis.T @ a @ basis

    return torch.linalg.eigh(restricted)

AliG

Bases: BaseOptimizer

Adaptive Learning Rates for Interpolation with Gradients.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
max_lr Optional[float]

Optional[float]. max learning rate.

None
projection_fn Optional[Callable]

Callable. projection function to enforce constraints.

None
momentum float

float. momentum.

0.0
adjusted_momentum bool

bool. if True, use pytorch-like momentum, instead of standard Nesterov momentum.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/alig.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class AliG(BaseOptimizer):
    r"""Adaptive Learning Rates for Interpolation with Gradients.

    The step size is derived from the current loss value and the global gradient norm, optionally clipped at
    `max_lr`, so `step` requires a closure that returns the loss.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param max_lr: Optional[float]. max learning rate. when None, the step size is not clipped.
    :param projection_fn: Optional[Callable]. projection function to enforce constraints. it is called without
        arguments once at construction and once at the end of every step.
    :param momentum: float. momentum.
    :param adjusted_momentum: bool. if True, use pytorch-like momentum, instead of standard Nesterov momentum.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        max_lr: Optional[float] = None,
        projection_fn: Optional[Callable] = None,
        momentum: float = 0.0,
        adjusted_momentum: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(max_lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)

        self.projection_fn = projection_fn
        self.maximize = maximize

        defaults: DEFAULTS = {'max_lr': max_lr, 'adjusted_momentum': adjusted_momentum, 'momentum': momentum}

        super().__init__(params, defaults)

        # start from a point that already satisfies the constraints.
        if self.projection_fn is not None:
            self.projection_fn()

    def __str__(self) -> str:
        return 'AliG'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the momentum buffer of every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :param kwargs: may carry `momentum`; when absent, the momentum of the group is used.
        """
        momentum = kwargs.get('momentum')
        if momentum is None:
            # fall back to the group setting so that a call without the keyword does not compare None to a float.
            momentum = group.get('momentum', 0.0)

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0 and momentum > 0.0:
                state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def compute_step_size(self, loss: float) -> float:
        r"""Compute step_size.

        :param loss: float. current loss value.
        :return: float. loss divided by the global gradient norm (plus 1e-6).
        """
        global_grad_norm = get_global_gradient_norm(self.param_groups)
        global_grad_norm.add_(1e-6)

        return loss / global_grad_norm.item()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. callable returning the current loss (e.g. `lambda: float(loss)`). required.
        :return: LOSS. value returned by `closure`.
        :raises NoClosureError: when `closure` is None.
        """
        if closure is None:
            raise NoClosureError('AliG', '(e.g. `optimizer.step(lambda: float(loss))`).')

        loss = closure()

        un_clipped_step_size: float = self.compute_step_size(loss)

        for group in self.param_groups:
            momentum = group['momentum']

            if 'step' not in group:
                self.init_group(group, momentum=momentum)
                group['step'] = 1
            else:
                group['step'] += 1

            # the step size actually used is also exposed on the group.
            step_size = group['step_size'] = (
                min(un_clipped_step_size, group['max_lr']) if group['max_lr'] is not None else un_clipped_step_size
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                p, grad, buffer = self.view_as_real(p, grad, state.get('momentum_buffer', None))

                p.add_(grad, alpha=-step_size)

                if buffer is not None:
                    if group['adjusted_momentum']:
                        buffer.mul_(momentum).sub_(grad)
                        p.add_(buffer, alpha=step_size * momentum)
                    else:
                        buffer.mul_(momentum).add_(grad, alpha=-step_size)
                        p.add_(buffer, alpha=momentum)

        # project once, after every group has been updated, so the projection never sees a partial update.
        if self.projection_fn is not None:
            self.projection_fn()

        return loss

compute_step_size(loss)

Compute step_size.

Source code in pytorch_optimizer/optimizer/alig.py
74
75
76
77
78
79
80
@torch.no_grad()
def compute_step_size(self, loss: float) -> float:
    r"""Return the loss divided by the global gradient norm of all parameter groups.

    :param loss: float. current loss value.
    :return: float. step size before any clipping.
    """
    # 1e-6 is added in-place to the tensor returned by the helper to keep the division finite.
    grad_norm = get_global_gradient_norm(self.param_groups).add_(1e-6)
    denominator = grad_norm.item()

    return loss / denominator

Amos

Bases: BaseOptimizer

An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
beta float

float. A float slightly < 1. We recommend setting 1 - beta to the same order of magnitude as the learning rate. It plays a role similar to beta2 in Adam.

0.999
momentum float

float. Exponential decay rate for optional moving average of updates.

0.0
extra_l2 float

float. Additional L2 regularization.

0.0
c_coef float

float. Coefficient for decay_factor_c.

0.25
d_coef float

float. Coefficient for decay_factor_d.

0.25
eps float

float. term added to the denominator to improve numerical stability.

1e-18
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/amos.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class Amos(BaseOptimizer):
    r"""An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. A float slightly < 1. We recommend setting `1 - beta` to the same order of magnitude as the
        learning rate. It plays a role similar to beta2 in Adam.
    :param momentum: float. Exponential decay rate for optional moving average of updates.
    :param extra_l2: float. Additional L2 regularization.
    :param c_coef: float. Coefficient for decay_factor_c.
    :param d_coef: float. Coefficient for decay_factor_d.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        beta: float = 0.999,
        momentum: float = 0.0,
        extra_l2: float = 0.0,
        c_coef: float = 0.25,
        d_coef: float = 0.25,
        eps: float = 1e-18,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(extra_l2, 'extra_l2')
        self.validate_non_negative(eps, 'eps')

        self.c_coef = c_coef
        self.d_coef = d_coef
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'beta': beta,
            'momentum': momentum,
            'extra_l2': extra_l2,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Amos'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the optimizer state of every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # one statistic per tensor. a 0-dim parameter gets 0-dim statistics because a shape-(1,) tensor
                # cannot be broadcast in-place into a 0-dim update.
                stat_shape = () if p.ndim == 0 else (1,)

                state['exp_avg_sq'] = torch.zeros(stat_shape, dtype=p.dtype, device=p.device)
                state['decay'] = torch.zeros(stat_shape, dtype=p.dtype, device=p.device)
                if group['momentum'] > 0.0:
                    state['exp_avg'] = torch.zeros_like(p)

    @staticmethod
    def get_scale(p: torch.Tensor) -> float:
        r"""Get expected scale for model weights.

        :param p: torch.Tensor. parameter.
        :return: float. 0.5 for 0-dim and 1-D tensors, `sqrt(2 / p.size(1))` for 2-D tensors and
            `sqrt(1 / p.size(1))` for tensors with more dimensions.
        """
        # 0-dim tensors have no `size(1)`; they are treated like 1-D tensors.
        if p.ndim < 2:
            return 0.5
        if p.ndim == 2:
            return math.sqrt(2 / p.size(1))
        return math.sqrt(1 / p.size(1))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            momentum, beta = group['momentum'], group['beta']

            lr_sq: float = math.sqrt(group['lr'])
            bias_correction: float = self.debias(beta, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # a single mean-square statistic is tracked for the whole tensor.
                g2 = grad.pow(2).mean()
                init_lr: float = group['lr'] * self.get_scale(p)

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta).add_(g2, alpha=1.0 - beta)

                r_v_hat = bias_correction / (exp_avg_sq + group['eps'])

                b = state['decay']
                decay_factor_c = torch.rsqrt(1.0 + self.c_coef * lr_sq * b)
                decay_factor_d = torch.reciprocal(1.0 + self.d_coef * math.sqrt(init_lr) * b)

                gamma = decay_factor_c * (group['lr'] ** 2) * r_v_hat * g2

                # weight-decay part (scaled parameter) plus the normalized gradient part.
                update = p.clone()
                update.mul_((gamma - group['extra_l2']) / 2.0)
                update.add_(r_v_hat.sqrt() * grad, alpha=init_lr)
                update.mul_(decay_factor_d)

                b.mul_(1.0 + gamma).add_(gamma)

                if momentum > 0.0:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(momentum).add_(update, alpha=1.0 - momentum)

                    update.copy_(exp_avg)

                p.sub_(update)

        return loss

get_scale(p) staticmethod

Get expected scale for model weights.

Source code in pytorch_optimizer/optimizer/amos.py
78
79
80
81
82
83
84
85
@staticmethod
def get_scale(p: torch.Tensor) -> float:
    r"""Get expected scale for model weights.

    :param p: torch.Tensor. parameter.
    :return: float. 0.5 for 0-dim and 1-D tensors, `sqrt(2 / p.size(1))` for 2-D tensors and
        `sqrt(1 / p.size(1))` for tensors with more dimensions.
    """
    # 0-dim tensors have no `size(1)`; they are treated like 1-D tensors instead of raising IndexError.
    if p.ndim < 2:
        return 0.5
    if p.ndim == 2:
        return math.sqrt(2 / p.size(1))
    return math.sqrt(1 / p.size(1))

APOLLO

Bases: BaseOptimizer

SGD-like Memory, AdamW-level Performance.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
correct_bias bool

bool. Whether to correct bias in Adam.

True
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/apollo.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
class APOLLO(BaseOptimizer):
    r"""SGD-like Memory, AdamW-level Performance.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of its square.
    :param scale_type: SCALE_TYPE. granularity of the gradient scaling factor used when low-rank projection is enabled.
        'channel' computes one factor per row or column, any other value computes a single factor per tensor.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param correct_bias: bool. Whether to correct bias in Adam.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: Dict. extra entries merged into the group defaults. `step` reads the optional keys 'rank',
        'update_proj_gap', 'scale' and 'projection_type' from the group to enable the low-rank projection.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        betas: BETAS = (0.9, 0.999),
        scale_type: SCALE_TYPE = 'tensor',
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        correct_bias: bool = True,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        # `maximize` is kept on the optimizer instance, not in the per-group defaults.
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'scale_type': scale_type,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'correct_bias': correct_bias,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'APOLLO'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the first and second moment buffers for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            # NOTE(review): the buffers are shaped like `p`, while `step` may replace `grad` with a projected
            # gradient before accumulating into them -- confirm the projector keeps a compatible shape.
            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # state is only created on the group's first step, for the parameters that already have a gradient.
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            step_size: float = group['lr']
            if group['correct_bias']:
                # `debias` comes from the base class; presumably 1 - beta ** step -- TODO confirm.
                bias_correction1: float = self.debias(beta1, group['step'])
                bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
                step_size *= bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # presumably converts complex tensors to real views; the helper is defined outside this class.
                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                # the low-rank path is enabled only when the group carries a 'rank' key and `p` has more than one dim.
                if 'rank' in group and p.dim() > 1:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )

                    grad = state['projector'].project(grad, group['step'], from_random_matrix=True)

                # exponential moving averages of the (possibly projected) gradient and of its square.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                norm_grad = exp_avg / de_nom
                if 'rank' in group and p.dim() > 1:
                    # ratio between the norm of the normalized update and the norm of the projected gradient,
                    # computed along one dimension ('channel') or over the whole tensor (any other value).
                    if group['scale_type'] == 'channel':
                        norm_dim: int = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
                        scaling_factor = torch.norm(norm_grad, dim=norm_dim) / (torch.norm(grad, dim=norm_dim) + 1e-8)
                        if norm_dim == 1:
                            scaling_factor = scaling_factor.unsqueeze(1)
                    else:
                        scaling_factor = torch.norm(norm_grad) / (torch.norm(grad) + 1e-8)

                    scaling_grad = grad * scaling_factor

                    scaling_grad_norm = torch.norm(scaling_grad)
                    if 'scaling_grad' in state:
                        # limit the growth of the scaled-gradient norm to 1.01x the norm stored on the previous step;
                        # the limiter is 1 whenever the growth ratio is at most 1.01.
                        limiter = (
                            max(
                                scaling_grad_norm / (state['scaling_grad'] + 1e-8),
                                1.01,
                            )
                            / 1.01
                        )

                        scaling_grad.div_(limiter)
                        scaling_grad_norm.div_(limiter)

                    # remember the (limited) norm for the next step's growth check.
                    state['scaling_grad'] = scaling_grad_norm

                    norm_grad = scaling_grad * np.sqrt(group['scale'])
                    norm_grad = state['projector'].project_back(norm_grad)

                p.add_(norm_grad, alpha=-step_size)

                # weight decay is applied after the parameter update, using the bias-corrected step size as `lr`.
                self.apply_weight_decay(
                    p,
                    grad,
                    lr=step_size,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

        return loss

ApolloDQN

Bases: BaseOptimizer

An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
init_lr Optional[float]

Optional[float]. initial learning rate (default lr / 1000).

1e-05
beta float

float. coefficient used for computing running averages of gradient.

0.9
rebound str

str. rectified bound for diagonal hessian. (constant, belief).

'constant'
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decay_type str

str. type of weight decay. (l2, decoupled, stable).

'l2'
warmup_steps int

int. number of warmup steps.

500
eps float

float. term added to the denominator to improve numerical stability.

0.0001
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/apollo.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class ApolloDQN(BaseOptimizer):
    r"""An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param init_lr: Optional[float]. initial learning rate (default lr / 1000).
    :param beta: float. coefficient used for computing running averages of gradient.
    :param rebound: str. rectified bound for diagonal hessian. (constant, belief).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decay_type: str. type of weight decay. (l2, decoupled, stable).
    :param warmup_steps: int. number of warmup steps.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: Dict. accepted for signature compatibility; not stored in the group defaults.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        init_lr: Optional[float] = 1e-5,
        beta: float = 0.9,
        rebound: str = 'constant',
        weight_decay: float = 0.0,
        weight_decay_type: str = 'l2',
        warmup_steps: int = 500,
        eps: float = 1e-4,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_options(rebound, 'rebound', ['constant', 'belief'])
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_options(weight_decay_type, 'weight_decay_type', ['l2', 'decoupled', 'stable'])
        self.validate_non_negative(eps, 'eps')

        self.lr = lr
        self.warmup_steps = warmup_steps
        self.init_lr: float = init_lr if init_lr is not None else lr / 1000.0
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'init_lr': self.init_lr,
            'beta': beta,
            'rebound': rebound,
            'weight_decay': weight_decay,
            'weight_decay_type': weight_decay_type,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ApolloDQN'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the gradient average, hessian approximation and update buffers for the group's parameters.

        :param group: GROUP. parameter group to initialize.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg_grad'] = torch.zeros_like(p)
                state['approx_hessian'] = torch.zeros_like(p)
                state['update'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            # linear warmup from `init_lr` to the constructor `lr` over `warmup_steps` steps.
            current_lr: float = (
                group['lr']
                if group['step'] >= self.warmup_steps
                else (self.lr - group['init_lr']) * group['step'] / self.warmup_steps + group['init_lr']
            )

            weight_decay, group_eps = group['weight_decay'], group['eps']

            bias_correction: float = self.debias(group['beta'], group['step'])
            alpha: float = (1.0 - group['beta']) / bias_correction

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg_grad, b, d_p = state['exp_avg_grad'], state['approx_hessian'], state['update']

                p, grad, exp_avg_grad, b, d_p = self.view_as_real(p, grad, exp_avg_grad, b, d_p)

                if weight_decay > 0.0 and group['weight_decay_type'] == 'l2':
                    grad.add_(p, alpha=weight_decay)

                # `eps` is re-derived from the group value for every parameter. Rescaling a variable shared by the
                # whole loop would multiply it by 1 / rebound once per parameter.
                eps = group_eps

                delta_grad = grad - exp_avg_grad
                if group['rebound'] == 'belief':
                    rebound = delta_grad.norm(p=np.inf)
                else:
                    rebound = 1e-2
                    eps = group_eps / rebound

                exp_avg_grad.add_(delta_grad, alpha=alpha)

                # normalize the previous update by its 4-norm before using it in the hessian correction.
                de_nom = d_p.norm(p=4).add_(eps)
                d_p.div_(de_nom)

                v_sq = d_p.mul(d_p)
                delta = delta_grad.div_(de_nom).mul_(d_p).sum().mul(-alpha) - b.mul(v_sq).sum()

                b.addcmul_(v_sq, delta)

                # rectify the diagonal hessian approximation: absolute value, bounded below by `rebound`.
                de_nom = b.abs().clamp_(min=rebound)
                if group['rebound'] == 'belief':
                    de_nom.add_(eps / alpha)

                d_p.copy_(exp_avg_grad.div(de_nom))

                if weight_decay > 0.0 and group['weight_decay_type'] != 'l2':
                    # per-parameter decay: the 'stable' rescaling must not leak into the next parameter.
                    decay = weight_decay
                    if group['weight_decay_type'] == 'stable':
                        decay = weight_decay / de_nom.mean().item()

                    d_p.add_(p, alpha=decay)

                p.add_(d_p, alpha=-current_lr)

        return loss

AvaGrad

Bases: BaseOptimizer

Domain-independent Dominance of Adaptive Methods.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.1
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and of the squared gradient.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

0.1
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/avagrad.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
class AvaGrad(BaseOptimizer):
    r"""Domain-independent Dominance of Adaptive Methods.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: Dict. extra entries merged into the group defaults. `step` reads the optional key 'adam_debias'.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-1,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # 'gamma' starts as None and is filled in at the end of every `step` call.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'gamma': None,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AvaGrad'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the first and second moment buffers for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        The parameters are not moved on a group's first step: that step only updates the moment buffers and computes
        the group's 'gamma', which scales the learning rate from the second step on.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            # `debias` comes from the base class; presumably 1 - beta ** step -- TODO confirm.
            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
            prev_bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step'] - 1))

            # `step_size` is only defined (and only used below) from the second step on, once 'gamma' is available.
            if group['step'] > 1:
                step_size: float = self.apply_adam_debias(
                    adam_debias=group.get('adam_debias', False),
                    step_size=group['gamma'] * group['lr'],
                    bias_correction1=bias_correction1,
                )

            # accumulators over the whole group, used to compute 'gamma' after the parameter loop.
            squared_norm: float = 0.0
            num_params: float = 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                # weight decay is applied before the moment updates, with the raw group learning rate.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                # square root of the second moment from the PREVIOUS step; `exp_avg_sq` is updated further below.
                sqrt_exp_avg_sq = exp_avg_sq.sqrt()

                if group['step'] > 1:
                    # out-of-place `div` keeps `sqrt_exp_avg_sq` intact for the in-place reuse below.
                    de_nom = sqrt_exp_avg_sq.div(prev_bias_correction2_sq).add_(group['eps'])

                    p.addcdiv_(exp_avg, de_nom, value=-step_size)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # `sqrt_exp_avg_sq` is consumed in place here; it is not used again afterwards.
                param_wise_lr = sqrt_exp_avg_sq.div_(bias_correction2_sq).add_(group['eps'])
                # a (-2)-norm raised to the power -2 is the sum of the element-wise inverse squares.
                squared_norm += param_wise_lr.norm(-2) ** -2
                num_params += param_wise_lr.numel()

            # 'gamma' is stored on the group and consumed by the next call; 0.0 when no parameter had a gradient.
            group['gamma'] = 0.0 if num_params == 0.0 else 1.0 / math.sqrt(squared_norm / num_params)

        return loss

BSAM

Bases: BaseOptimizer

SAM as an Optimal Relaxation of Bayes.

Example:

Here's an example::

    model = YourModel()
    optimizer = BSAM(model.parameters(), ...)

    def closure():
        loss = loss_function(output, model(input))
        loss.backward()
        return loss

    for input, output in data:
        loss = loss_function(output, model(input))
        loss.backward()

        optimizer.step(closure)
        optimizer.zero_grad()

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
num_data int

int. number of training data.

required
lr float

float. learning rate.

0.5
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0001
rho float

float. size of the neighborhood for computing the max loss.

0.05
adaptive bool

bool. element-wise Adaptive SAM.

False
damping float

float. damping to stabilize the method.

0.1
kwargs

Dict. parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
class BSAM(BaseOptimizer):
    r"""SAM as an Optimal Relaxation of Bayes.

    Example:
    -------
        Here's an example::

            model = YourModel()
            optimizer = BSAM(model.parameters(), ...)

            def closure():
                loss = loss_function(output, model(input))
                loss.backward()
                return loss

            for input, output in data:
                loss = loss_function(output, model(input))
                loss.backward()

                optimizer.step(closure)
                optimizer.zero_grad()

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param num_data: int. number of training data.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for the running averages of the momentum buffer and of the scale `s`.
    :param weight_decay: float. weight decay (L2 penalty).
    :param rho: float. size of the neighborhood for computing the max loss.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param damping: float. damping to stabilize the method.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        params: PARAMETERS,
        num_data: int,
        lr: float = 5e-1,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 1e-4,
        rho: float = 0.05,
        adaptive: bool = False,
        damping: float = 0.1,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(rho, 'rho')
        self.validate_non_negative(num_data, 'num_data')
        self.validate_non_negative(damping, 'damping')

        # `num_data` and `damping` are kept on the optimizer instance, not in the per-group defaults.
        self.num_data = num_data
        self.damping = damping

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'rho': rho,
            'adaptive': adaptive,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'bSAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the `s`, `noisy_gradient` and `momentum` buffers for the group's parameters that have a gradient.

        :param group: GROUP. parameter group to initialize.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            state = self.state[p]

            if 's' not in state:
                state['s'] = torch.ones_like(p)
                state['noisy_gradient'] = torch.zeros_like(p.grad)
                state['momentum'] = torch.zeros_like(p)

    @torch.no_grad()
    def first_step(self):
        r"""Advance the step counter and add zero-mean Gaussian noise to every parameter that has a gradient."""
        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # element-wise standard deviation 1 / (num_data * s).
                noise = torch.normal(0.0, 1 / (self.num_data * state['s']))

                # NOTE(review): the noise is added in place and is not subtracted again anywhere in this class.
                p.add_(noise)

    @torch.no_grad()
    def second_step(self):
        r"""Store the current gradient and move every parameter along it, scaled by `rho` and divided by `s`."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                # NOTE(review): 'noisy_gradient' is written here but never read inside this class.
                state['noisy_gradient'] = p.grad.clone()

                # with `adaptive`, the perturbation is additionally scaled element-wise by p ** 2.
                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * group['rho'] * p.grad / state['s']

                # NOTE(review): this perturbation is not undone before the update in `third_step`.
                p.add_(e_w)

    @torch.no_grad()
    def third_step(self):
        r"""Update the momentum and `s` buffers from the current gradient, then update the parameters."""
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                momentum, s = state['momentum'], state['s']
                # NOTE(review): the gradient is multiplied by `weight_decay` here -- confirm this is intended.
                momentum.mul_(beta1).add_(p.grad * weight_decay, alpha=1.0 - beta1)

                # `torch.sqrt(s)` allocates a new tensor, so the in-place chain below does not modify `s`.
                var = (torch.sqrt(s).mul_(p.grad.abs()).add_(weight_decay + self.damping)).pow_(2)
                s.mul_(beta2).add_(var, alpha=1.0 - beta2)

                p.add_(momentum / s, alpha=-group['lr'])

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        r"""Perform a single optimization step, evaluating `closure` twice.

        :param closure: CLOSURE. callable that re-evaluates the model, calls backward and returns the loss.
        :return: the loss returned by the second `closure` call.
        :raises NoClosureError: when `closure` is None.
        """
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step()

        # first evaluation: its return value is discarded, only the gradients it produces are used.
        with torch.enable_grad():
            closure()

        self.second_step()

        # second evaluation: gradients at the perturbed parameters drive `third_step`.
        with torch.enable_grad():
            loss = closure()

        self.third_step()

        return loss

CAME

Bases: BaseOptimizer

Confidence-guided Adaptive Memory Efficient Optimization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0002
betas BETAS

BETAS. three coefficients used for the running averages of the update, of the squared gradient, and of the squared residual between the update and its running average.

(0.9, 0.999, 0.9999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
clip_threshold float

float. threshold of root-mean-square of final gradient update.

1.0
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps1 float

float. term added to the denominator to improve numerical stability.

1e-30
eps2 float

float. term added to the denominator to improve numerical stability.

1e-16
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/came.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class CAME(BaseOptimizer):
    r"""Confidence-guided Adaptive Memory Efficient Optimization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. three coefficients used for the running averages of the update, of the squared gradient and
        of the squared residual between the update and its running average.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param clip_threshold: float. threshold of root-mean-square of final gradient update.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    :param kwargs: Dict. accepted for signature compatibility; not stored in the group defaults.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 2e-4,
        betas: BETAS = (0.9, 0.999, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        clip_threshold: float = 1.0,
        ams_bound: bool = False,
        eps1: float = 1e-30,
        eps2: float = 1e-16,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        # `step` reads these from the instance, not from the parameter group.
        self.clip_threshold = clip_threshold
        self.eps1 = eps1
        self.eps2 = eps2
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps1': eps1,
            'eps2': eps2,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'CAME'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the optimizer state for every parameter of the group that has a gradient.

        Tensors with at least two dimensions get factored (row / column) statistics; other tensors get a full
        second-moment buffer.

        :param group: GROUP. parameter group to initialize.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            grad_shape: Tuple[int, ...] = grad.shape
            factored: bool = self.get_options(grad_shape)

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)

                if factored:
                    # row statistics drop the last dim, column statistics drop the second-to-last dim.
                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_sq_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                    state['exp_avg_res_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_res_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                if group['ams_bound']:
                    state['exp_avg_sq_hat'] = torch.zeros_like(grad)

                state['RMS'] = 0.0

    @staticmethod
    def get_options(shape: Tuple[int, ...]) -> bool:
        r"""Get `factored`: True when the shape has at least two dimensions."""
        return len(shape) >= 2

    @staticmethod
    def get_rms(x: torch.Tensor) -> float:
        r"""Get RMS."""
        return x.norm(2) / math.sqrt(x.numel())

    @staticmethod
    def approximate_sq_grad(
        exp_avg_sq_row: torch.Tensor,
        exp_avg_sq_col: torch.Tensor,
        output: torch.Tensor,
    ):
        r"""Get approximation of EMA of squared gradient.

        The result, the outer product of the row and column inverse-square-root factors, is written into `output`.
        """
        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2, beta3 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_shape: Tuple[int, ...] = grad.shape
                factored: bool = self.get_options(grad_shape)

                state['RMS'] = self.get_rms(p)

                # `update` starts as grad ** 2 + eps1 and is then overwritten with the inverse-root preconditioner.
                update = torch.mul(grad, grad).add_(self.eps1)

                if factored:
                    exp_avg_sq_row, exp_avg_sq_col = state['exp_avg_sq_row'], state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2).add_(update.mean(dim=-1), alpha=1.0 - beta2)
                    exp_avg_sq_col.mul_(beta2).add_(update.mean(dim=-2), alpha=1.0 - beta2)

                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2).add_(update, alpha=1.0 - beta2)
                    torch.rsqrt(exp_avg_sq, out=update)

                if group['ams_bound']:
                    exp_avg_sq_hat = state['exp_avg_sq_hat']
                    torch.max(exp_avg_sq_hat, 1 / update, out=exp_avg_sq_hat)
                    torch.rsqrt(exp_avg_sq_hat / beta2, out=update)

                update.mul_(grad)

                # clip the update so that its root-mean-square does not exceed `clip_threshold`.
                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0))

                exp_avg = state['exp_avg']
                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

                # squared residual between the current update and its running average.
                res = update - exp_avg
                res.pow_(2).add_(self.eps2)

                if factored:
                    exp_avg_res_row, exp_avg_res_col = state['exp_avg_res_row'], state['exp_avg_res_col']

                    exp_avg_res_row.mul_(beta3).add_(res.mean(dim=-1), alpha=1.0 - beta3)
                    exp_avg_res_col.mul_(beta3).add_(res.mean(dim=-2), alpha=1.0 - beta3)

                    self.approximate_sq_grad(exp_avg_res_row, exp_avg_res_col, update)
                    update.mul_(exp_avg)
                else:
                    # clone: `update` is scaled in place below, which must not touch the stored `exp_avg` buffer.
                    update = exp_avg.clone()

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                update.mul_(group['lr'])

                p.add_(-update)

        return loss

approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output) staticmethod

Get approximation of EMA of squared gradient.

Source code in pytorch_optimizer/optimizer/came.py
116
117
118
119
120
121
122
123
124
125
@staticmethod
def approximate_sq_grad(
    exp_avg_sq_row: torch.Tensor,
    exp_avg_sq_col: torch.Tensor,
    output: torch.Tensor,
):
    r"""Write the factored approximation of the inverse root of the squared-gradient EMA into `output`.

    :param exp_avg_sq_row: torch.Tensor. row statistics; normalized by their mean over the last dim.
    :param exp_avg_sq_col: torch.Tensor. column statistics.
    :param output: torch.Tensor. destination buffer, overwritten with the outer product of both rsqrt factors.
    """
    # normalize the row statistics by their mean, then take the reciprocal square root (on the temporary only).
    row_mean = exp_avg_sq_row.mean(dim=-1, keepdim=True)
    row_scale: torch.Tensor = exp_avg_sq_row.div(row_mean).rsqrt_().unsqueeze(-1)

    # the column statistics are used as-is; `rsqrt` (not in-place) leaves the caller's tensor untouched.
    col_scale: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()

    # broadcasted product (..., rows, 1) * (..., 1, cols), written straight into the caller's buffer.
    torch.mul(row_scale, col_scale, out=output)

get_options(shape) staticmethod

Get `factored`: whether the given shape has at least two dimensions.

Source code in pytorch_optimizer/optimizer/came.py
106
107
108
109
@staticmethod
def get_options(shape: Tuple[int, ...]) -> bool:
    r"""Get `factored`, i.e. whether `shape` has two or more dimensions.

    :param shape: Tuple[int, ...]. shape to inspect.
    """
    n_dims = len(shape)
    return n_dims > 1

get_rms(x) staticmethod

Get the root-mean-square (RMS) of the tensor's elements.

Source code in pytorch_optimizer/optimizer/came.py
111
112
113
114
@staticmethod
def get_rms(x: torch.Tensor) -> torch.Tensor:
    r"""Get RMS.

    Computes ``||x||_2 / sqrt(numel(x))``, i.e. the root-mean-square of the elements of `x`.

    :param x: torch.Tensor. input tensor.
    :return: torch.Tensor. 0-dim tensor holding the RMS. It is a tensor rather than a Python float, so the result
        can be used directly in further tensor operations.
    """
    return x.norm(2) / math.sqrt(x.numel())

DAdaptAdaGrad

Bases: BaseOptimizer

AdaGrad with D-Adaptation. Leave LR set to 1 unless you encounter instability.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
momentum float

float. momentum.

0.0
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

0.0
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
class DAdaptAdaGrad(BaseOptimizer):
    r"""AdaGrad with D-Adaptation. Leave LR set to 1 unless you encounter instability.

    The D estimate, the learning rate used for the accumulators and the running statistics (`gsq_weighted`,
    `sk_sq_weighted`, `sk_l1`) are read from the first parameter group at the start of `step`; the statistics and the
    D estimate are written back to every group at the end of `step`.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        momentum: float = 0.0,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        # validate the hyper-parameters before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'd': d0,  # current D estimate; starts at d0 and is updated by step()
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'k': 0,  # incremented once per group at the end of every step() that reaches the parameter update
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdaGrad'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoComplexParameterError: if a parameter with a gradient is complex.
        """
        for p in group['params']:
            # parameters without a gradient at this point get no state.
            if p.grad is None:
                continue

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'alpha_k' not in state:
                # alpha_k: accumulator of squared gradients, started at 1e-6 rather than 0.
                state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
                # sk: accumulator of d * lr * grad.
                state['sk'] = torch.zeros_like(p)
                # x0: copy of the parameter at initialization time; updates are expressed relative to it.
                state['x0'] = torch.clone(p)
                # weighted_sk is only needed (and only read) by the sparse-gradient branch of step().
                if p.grad.is_sparse:
                    state['weighted_sk'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss; it is called
            with gradients enabled.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # the D estimate, lr and the running statistics are taken from the first parameter group.
        group = self.param_groups[0]
        device = group['params'][0].device

        d, lr = group['d'], group['lr']
        d_lr: float = d * lr

        # per-step accumulators, summed over every parameter of every group.
        g_sq = torch.tensor([0.0], device=device)
        sk_sq_weighted_change = torch.tensor([0.0], device=device)
        sk_l1_change = torch.tensor([0.0], device=device)
        # running statistics are created lazily on the first group the first time step() runs.
        if 'gsq_weighted' not in group:
            group['gsq_weighted'] = torch.tensor([0.0], device=device)
        if 'sk_sq_weighted' not in group:
            group['sk_sq_weighted'] = torch.tensor([0.0], device=device)
        if 'sk_l1' not in group:
            group['sk_l1'] = torch.tensor([0.0], device=device)

        gsq_weighted = group['gsq_weighted']
        sk_sq_weighted = group['sk_sq_weighted']
        sk_l1 = group['sk_l1']

        # first pass: update the accumulators (alpha_k, sk) and collect the statistics needed for the D estimate.
        for group in self.param_groups:
            # a group without a 'step' key has not been initialized yet.
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                sk, alpha_k = state['sk'], state['alpha_k']

                if grad.is_sparse:
                    # sparse path: only the entries at the gradient's indices are read and updated.
                    # NOTE: weight decay is not applied on this path.
                    weighted_sk = state['weighted_sk']

                    grad = grad.coalesce()

                    vk = grad._values().pow(2)
                    # l1 mass of sk at the gradient's indices *before* sk is updated.
                    sk_masked = sk.sparse_mask(grad).coalesce()
                    old_sk_l1_masked = sk_masked._values().abs().sum()

                    sk.add_(grad, alpha=d_lr)

                    # re-read the (now updated) sk plus the other accumulators at the gradient's indices.
                    sk_masked = sk.sparse_mask(grad).coalesce()
                    alpha_k_masked = alpha_k.sparse_mask(grad).coalesce()
                    weighted_sk_masked = weighted_sk.sparse_mask(grad).coalesce()

                    # update alpha before step
                    alpha_k_p1_masked = alpha_k_masked._values() + vk

                    # add the increment back into alpha_k through a sparse tensor at the gradient's indices.
                    alpha_k_delta_masked = alpha_k_p1_masked - alpha_k_masked._values()
                    alpha_k_delta = torch.sparse_coo_tensor(grad.indices(), alpha_k_delta_masked, grad.shape)
                    alpha_k.add_(alpha_k_delta)

                    # here eps is added inside the square root (the dense path adds it outside).
                    de_nom = torch.sqrt(alpha_k_p1_masked + eps)

                    grad_sq = vk.div(de_nom).sum()
                    g_sq.add_(grad_sq)

                    # update weighted sk sq tracking
                    weighted_sk_p1_masked = sk_masked._values().pow(2).div(de_nom)

                    sk_sq_weighted_change.add_(weighted_sk_p1_masked.sum() - weighted_sk_masked._values().sum())

                    weighted_sk_p1_delta_masked = weighted_sk_p1_masked - weighted_sk_masked._values()
                    weighted_sk_p1_delta = torch.sparse_coo_tensor(
                        grad.indices(), weighted_sk_p1_delta_masked, grad.shape
                    )
                    weighted_sk.add_(weighted_sk_p1_delta)

                    # change of the l1 norm of sk caused by this gradient.
                    sk_l1_masked = sk_masked._values().abs().sum()
                    sk_l1_change.add_(sk_l1_masked - old_sk_l1_masked)
                else:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                    # contributions of this parameter *before* alpha_k and sk are updated.
                    old_sk_sq_weighted_param = sk.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    old_sk_l1_param = sk.abs().sum()

                    alpha_k.add_(grad.pow(2))
                    grad_sq = grad.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    g_sq.add_(grad_sq)

                    sk.add_(grad, alpha=d_lr)

                    # contributions of this parameter *after* the update; only the differences are accumulated.
                    sk_sq_weighted_param = sk.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    sk_l1_param = sk.abs().sum()

                    sk_sq_weighted_change.add_(sk_sq_weighted_param - old_sk_sq_weighted_param)
                    sk_l1_change.add_(sk_l1_param - old_sk_l1_param)

        # fold this step's changes into the running statistics (in-place on the tensors held by the first group).
        sk_sq_weighted.add_(sk_sq_weighted_change)
        gsq_weighted.add_(g_sq, alpha=d_lr ** 2)  # fmt: skip
        sk_l1.add_(sk_l1_change)

        # with a zero l1 norm the D estimate below would divide by zero; skip the D and parameter update entirely.
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            d_hat = (sk_sq_weighted - gsq_weighted) / sk_l1
            # d never decreases and grows by at most a factor of growth_rate per step.
            # NOTE(review): `group` is the last group iterated above, so growth_rate comes from the last group.
            d = group['d'] = max(d, min(d_hat.item(), d * group['growth_rate']))

        # second pass: publish the shared statistics to every group and move the parameters.
        for group in self.param_groups:
            group['gsq_weighted'] = gsq_weighted
            group['sk_sq_weighted'] = sk_sq_weighted
            group['sk_l1'] = sk_l1
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]

                alpha_k, sk, x0 = state['alpha_k'], state['sk'], state['x0']

                if grad.is_sparse:
                    grad = grad.coalesce()

                    sk_masked = sk.sparse_mask(grad).coalesce()._values()
                    alpha_k_masked = alpha_k.sparse_mask(grad).coalesce()._values()
                    x0_masked = x0.sparse_mask(grad).coalesce()._values()
                    p_masked = p.sparse_mask(grad).coalesce()._values()

                    # target location at the gradient's indices: x0 - sk / sqrt(alpha_k + eps).
                    loc_masked = x0_masked - sk_masked.div(torch.sqrt(alpha_k_masked + group['eps']))

                    # add (target - current) at those indices; momentum is not used on this path.
                    loc_delta_masked = loc_masked - p_masked
                    loc_delta = torch.sparse_coo_tensor(grad.indices(), loc_delta_masked, grad.shape)
                    p.add_(loc_delta)
                else:
                    # target location: x0 - sk / (sqrt(alpha_k) + eps).
                    z = x0 - sk.div(alpha_k.sqrt().add_(group['eps']))

                    if group['momentum'] > 0.0:
                        # p <- momentum * p + (1 - momentum) * z
                        p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])
                    else:
                        p.copy_(z)

            group['k'] += 1

        return loss

DAdaptAdam

Bases: BaseOptimizer

Adam with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. coefficients (beta1, beta2) used for computing running averages of the gradient and its square.

(0.9, 0.999)
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. use AdamW style weight decay.

False
fixed_decay bool

bool. fix weight decay.

False
bias_correction bool

bool. Turn on Adam's bias correction.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
class DAdaptAdam(BaseOptimizer):
    r"""Adam with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    The D estimate, the learning rate, the betas and the running numerator are read from the first parameter group
    at the start of `step`; the numerator and the D estimate are written back to every group at the end of `step`.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients (beta1, beta2) used for computing running averages of the gradient and its
        square.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use AdamW style weight decay.
    :param fixed_decay: bool. fix weight decay.
    :param bias_correction: bool. Turn on Adam's bias correction.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        bias_correction: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        # validate the hyper-parameters before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'd': d0,  # current D estimate; starts at d0 and is updated by step()
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'bias_correction': bias_correction,
            'step': 0,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdam'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of `group` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter with a gradient is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss; it is called
            with gradients enabled.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # hyper-parameters and the running numerator are taken from the first parameter group.
        group = self.param_groups[0]
        device = group['params'][0].device

        beta1, beta2 = group['betas']

        beta2_sq: float = math.sqrt(beta2)

        d: float = group['d']
        lr: float = group['lr']

        # group['step'] has not been incremented yet at this point, hence the `+ 1`.
        bias_correction1: float = 1.0 - beta1 ** (group['step'] + 1)
        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] + 1))
        bias_correction: float = bias_correction1 / bias_correction2_sq

        # presumably returns `d * lr` unchanged when bias correction is off and a debiased step size otherwise;
        # verify against `apply_adam_debias`.
        d_lr: float = self.apply_adam_debias(
            not group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
        )

        # per-step accumulators, summed over every parameter of every group.
        sk_l1 = torch.tensor([0.0], device=device)
        numerator_acc = torch.tensor([0.0], device=device)

        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        # first pass: update the moment estimates and collect the statistics needed for the D estimate.
        for group in self.param_groups:
            if group['step'] == 0:
                self.init_group(group)

            group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, s = state['exp_avg'], state['exp_avg_sq'], state['s']

                # the numerator uses the second-moment estimate and `s` from *before* this step's update.
                de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                numerator_acc.add_(torch.dot(grad.flatten(), s.div(de_nom).flatten()), alpha=d_lr)

                exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                s.mul_(beta2_sq).add_(grad, alpha=d_lr * (1.0 - beta2_sq))

                sk_l1.add_(s.abs().sum())

        # with a zero l1 norm the D estimate below would divide by zero; skip the D and parameter update entirely.
        if sk_l1 == 0:
            return loss

        numerator_weighted.mul_(beta2_sq).add_(numerator_acc, alpha=1.0 - beta2_sq)  # fmt: skip

        if lr > 0.0:
            # the l1 norm of `s` belongs in the denominator: d_hat = numerator / ((1 - sqrt(beta2)) * ||s||_1).
            # (`a / b * c` would evaluate as `(a / b) * c` and multiply by the norm instead.)
            d_hat = numerator_weighted / ((1.0 - beta2_sq) * sk_l1)
            # d never decreases and grows by at most a factor of growth_rate per step.
            d = max(d, min(d_hat.item(), d * group['growth_rate']))

        # second pass: publish the shared statistics to every group and move the parameters.
        for group in self.param_groups:
            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # exp_avg already carries the d_lr factor (see the first pass), so the step size here is -1.
                p.addcdiv_(exp_avg, de_nom, value=-1.0)

        return loss

DAdaptSGD

Bases: BaseOptimizer

SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
momentum float

float. momentum.

0.9
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
class DAdaptSGD(BaseOptimizer):
    r"""SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        momentum: float = 0.9,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        # hyper-parameter checks come first so nothing is stored for an invalid configuration.
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        # per-group defaults; 'd' holds the running D estimate and 'step' the number of completed steps.
        super().__init__(
            params,
            {
                'lr': lr,
                'momentum': momentum,
                'd': d0,
                'growth_rate': growth_rate,
                'weight_decay': weight_decay,
                'weight_decouple': weight_decouple,
                'fixed_decay': fixed_decay,
                'step': 0,
            },
        )

    def __str__(self) -> str:
        return 'DAdaptSGD'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Set up `z`, `s` and `x0` for each parameter of `group` that has a gradient and no state yet."""
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                continue

            param_state['z'] = param.clone()
            param_state['s'] = torch.zeros_like(param)
            param_state['x0'] = param.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Run one optimization step and return the closure's loss (None without a closure)."""
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # shared quantities live on the first parameter group.
        head = self.param_groups[0]
        device = head['params'][0].device

        sk_sq = torch.tensor([0.0], device=device)
        if 'numerator_weighted' not in head:
            head['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = head['numerator_weighted']

        # the gradient norm of the very first step is cached and reused as a fixed normalizer afterwards.
        if head['step'] == 0:
            head['g0_norm'] = get_global_gradient_norm(self.param_groups).sqrt_().item()
        g0_norm = head['g0_norm']

        if g0_norm == 0:
            return loss

        d, lr = head['d'], head['lr']
        d_lr: float = d * lr / g0_norm

        # pass 1: accumulate the numerator <grad, s> and the squared norm of the updated s.
        for group in self.param_groups:
            if group['step'] == 0:
                self.init_group(group)

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]

                self.apply_weight_decay(
                    p=param,
                    grad=None,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s = param_state['s']
                # the inner product uses s from before this step's gradient is added to it.
                projection = torch.dot(gradient.flatten(), s.flatten())
                numerator_weighted.add_(projection, alpha=d_lr)

                s.add_(gradient, alpha=d_lr)
                sk_sq.add_(s.pow(2).sum())

        if lr > 0.0:
            d_hat = 2.0 * numerator_weighted / sk_sq.sqrt()
            # `group` is still bound to the last group visited by the loop above.
            d = max(d, min(d_hat.item(), d * group['growth_rate']))

        # pass 2: store the shared values on every group and move the parameters.
        for group in self.param_groups:
            group['step'] += 1

            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            for param in group['params']:
                if param.grad is None:
                    continue

                param_state = self.state[param]

                z = param_state['z']
                z.copy_(param_state['x0'] - param_state['s'])

                # param <- momentum * param + (1 - momentum) * z
                momentum = group['momentum']
                param.mul_(momentum).add_(z, alpha=1.0 - momentum)

        return loss

DAdaptAdan

Bases: BaseOptimizer

Adan with D-Adaptation. Leave LR set to 1 unless you encounter instability.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.98, 0.92, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. decoupled weight decay.

False
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate. Default is inf, for unrestricted.

float('inf')
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
class DAdaptAdan(BaseOptimizer):
    r"""Adan with D-Adaptation. Leave LR set to 1 unless you encounter instability.

    ``betas``, ``growth_rate``, ``d`` and ``lr`` are read from the first parameter group only, and the D estimate is
    written back to every group, so all groups share a single D.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. decay rates (beta1, beta2, beta3) of the running averages of the gradient, of the gradient
        difference and of the squared term, respectively.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. decoupled weight decay.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
        Default is inf, for unrestricted.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.98, 0.92, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'd': d0,
            'growth_rate': growth_rate,
            'k': 0,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdan'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'exp_avg' not in state:
                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_diff'] = torch.zeros_like(p)
                # stored negated so that `step` obtains the gradient difference with a single in-place add.
                state['previous_grad'] = -grad.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # hyper-parameters and the D estimate are taken from the first group and shared by all groups.
        group = self.param_groups[0]
        device = group['params'][0].device

        beta1, beta2, beta3 = group['betas']
        growth_rate = group['growth_rate']

        d, lr = group['d'], group['lr']
        # computed from the D of the previous step; the new D only takes effect on the next call.
        d_lr = float(d * lr)

        g_sq = torch.tensor([0.0], device=device)
        sk_sq_weighted = torch.tensor([0.0], device=device)
        sk_l1 = torch.tensor([0.0], device=device)
        if 'gsq_weighted' not in group:
            group['gsq_weighted'] = torch.tensor([0.0], device=device)
        gsq_weighted = group['gsq_weighted']

        # first pass: update the moment buffers and accumulate the statistics needed for the D estimate.
        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # `previous_grad` holds the negated previous gradient, so this becomes (grad - previous gradient).
                grad_diff = state['previous_grad']
                grad_diff.add_(grad)

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=d_lr * (1.0 - beta2))

                grad_diff.mul_(beta2).add_(grad)
                # NOTE(review): `grad_diff` is squared here and `addcmul_` multiplies it by itself again;
                # confirm against the reference D-Adapt Adan implementation.
                grad_diff = to_real(grad_diff * grad_diff.conj())
                exp_avg_sq.mul_(beta3).addcmul_(grad_diff, grad_diff, value=1.0 - beta3)

                grad_power = to_real(grad * grad.conj())
                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                g_sq.add_(grad_power.div_(de_nom).sum())

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=d_lr * (1.0 - beta3))

                sk_sq_weighted.add_(to_real(s * s.conj()).div_(de_nom).sum())
                sk_l1.add_(s.abs().sum())

                # `grad_diff` was rebound above, so the state buffer is refreshed explicitly.
                state['previous_grad'].copy_(-grad)

        # nothing accumulated: skip the D update and the parameter update.
        if sk_l1 == 0:
            return loss

        gsq_weighted.mul_(beta3).add_(g_sq, alpha=(d_lr ** 2) * (1.0 - beta3))  # fmt: skip

        if lr > 0.0:
            # `.item()` keeps `d` (and therefore `group['d']`) a plain Python float, as in DAdaptLion, instead of
            # silently turning it into a 1-element device tensor after the first adaptation.
            d_hat: float = ((sk_sq_weighted / (1.0 - beta3) - gsq_weighted) / sk_l1).item()
            d = max(d, min(d_hat, d * growth_rate))

        # second pass: publish the shared statistics to every group and apply the parameter update.
        for group in self.param_groups:
            group['step'] += 1

            group['gsq_weighted'] = gsq_weighted
            group['d'] = d
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if group['weight_decouple']:
                    p.mul_(1.0 - d_lr * group['weight_decay'])

                p.addcdiv_(exp_avg, de_nom, value=-1.0)
                p.addcdiv_(exp_avg_diff, de_nom, value=-beta2)

                if not group['weight_decouple']:
                    p.div_(1.0 + d_lr * group['weight_decay'])

            group['k'] += 1

        return loss

DAdaptLion

Bases: BaseOptimizer

Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
class DAdaptLion(BaseOptimizer):
    r"""Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    ``betas``, ``d`` and ``lr`` are read from the first parameter group only, and the D estimate is written back to
    every group, so all groups share a single D.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. beta1 blends the momentum buffer with the current gradient for the signed update; beta2 is
        the decay rate of the momentum buffer, and its square root is the decay rate of the D-adaptation statistics.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        d0: float = 1e-6,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'd': d0,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'step': 0,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptLion'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            # only allocate once; `step` may call this again while the group's 'step' is still 0.
            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # hyper-parameters and the D estimate are taken from the first group and shared by all groups.
        group = self.param_groups[0]
        device = group['params'][0].device

        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        sk_l1 = torch.tensor([0.0], device=device)
        numerator_accumulator = torch.tensor([0.0], device=device)

        beta1, beta2 = group['betas']
        beta2_sq = math.sqrt(beta2)

        d, lr = group['d'], group['lr']
        # computed from the D of the previous step; the new D only takes effect on the next call.
        d_lr: float = d * lr

        for group in self.param_groups:
            if group['step'] == 0:
                self.init_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, s = state['exp_avg'], state['s']

                # sign of the beta1-blend of momentum and gradient, built on a clone so `exp_avg` is not modified.
                update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
                p.add_(update, alpha=-d_lr)

                exp_avg.mul_(beta2).add_(grad, alpha=(1.0 - beta2) * d_lr)

                # uses `s` from before this step's update, so this line must stay ahead of the `s` update below.
                numerator_accumulator.add_(torch.dot(update.flatten(), s.flatten()), alpha=d_lr)
                s.mul_(beta2_sq).add_(update, alpha=(1.0 - beta2_sq) * d_lr)

                sk_l1.add_(s.abs().sum())

        # updated in place on the tensor stored in the first group, even when the early return below is taken.
        numerator_weighted.mul_(beta2_sq).add_(numerator_accumulator, alpha=1.0 - beta2_sq)

        # nothing accumulated: leave `d` and the groups' 'step' counters unchanged.
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            d_hat: float = (numerator_weighted / ((1.0 - beta2_sq) * sk_l1)).item()
            # the D estimate never decreases.
            d = max(d, d_hat)

        for group in self.param_groups:
            group['step'] += 1

            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

        return loss

DeMo

Bases: SGD, BaseOptimizer

Decoupled Momentum Optimization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
compression_decay float

float. decay factor in [0, 1) applied to the accumulated delta buffer before each new gradient is added.

0.999
compression_top_k int

int. positive top-k value passed to the DCT compressor when compressing the transformed delta buffer.

32
compression_chunk int

int. positive chunk size passed to the DCT transform.

64
weight_decay float

float. weight decay (L2 penalty).

0.0
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/demo.py
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
class DeMo(torch.optim.SGD, BaseOptimizer):  # pragma: no cover
    r"""Decoupled Momentum Optimization.

    Each step accumulates the gradient into a per-parameter ``delta`` buffer, compresses the DCT-transformed buffer,
    all-gathers the compressed components from every rank, and hands the sign of the decoded result to
    ``torch.optim.SGD`` as the gradient.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param compression_decay: float. decay factor in [0, 1) applied to the ``delta`` buffer before each new gradient
        is added.
    :param compression_top_k: int. positive top-k value passed to the DCT compressor.
    :param compression_chunk: int. positive chunk size passed to the DCT transform.
    :param weight_decay: float. weight decay (L2 penalty). always applied in decoupled form.
    :param process_group: Optional[ProcessGroup]. process group used for the all-gather. None uses the default group.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        compression_decay: float = 0.999,
        compression_top_k: int = 32,
        compression_chunk: int = 64,
        weight_decay: float = 0.0,
        process_group: Optional[ProcessGroup] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(compression_decay, 'compression_decay', 0.0, 1.0, range_type='[)')
        self.validate_positive(compression_top_k, 'compression_top_k')
        self.validate_positive(compression_chunk, 'compression_chunk')

        self.weight_decay = weight_decay

        self.compression_decay = compression_decay
        self.compression_top_k = compression_top_k
        self.compression_chunk = compression_chunk
        self.process_group = process_group

        # bytes sent / received by the most recent `step` call.
        self.data_transmit: int = 0
        self.data_receive: int = 0

        self.maximize = maximize

        # momentum and weight decay are handled by this class, so the SGD versions are disabled. `maximize` is
        # forwarded so that SGD flips the signed gradient; it was previously hard-coded to False, which made the
        # `maximize` argument a silent no-op.
        super().__init__(
            params,
            lr=lr,
            foreach=False,
            momentum=0.0,
            dampening=0.0,
            nesterov=False,
            maximize=maximize,
            weight_decay=0.0,
            **kwargs,
        )

        self.demo_state = {}
        self.init_demo_states()
        self.init_parameters()

        self.default_dtype: torch.dtype = self.find_dtype()
        self.transform = TransformDCT(self.param_groups, self.compression_chunk, norm='ortho')
        self.compress = CompressDCT()

    def __str__(self) -> str:
        return 'DeMo'

    def find_dtype(self) -> torch.dtype:
        r"""Return the dtype of the first parameter that requires grad, or ``torch.float32`` if there is none."""
        for group in self.param_groups:
            for p in group['params']:
                if p.requires_grad:
                    return p.dtype
        return torch.float32

    def init_demo_states(self) -> None:
        r"""Create an empty DeMo state dict for every parameter that requires grad."""
        for group in self.param_groups:
            for p in group['params']:
                if p.requires_grad:
                    self.demo_state[p] = {}

    def init_parameters(self) -> None:
        r"""Reset every group's step counter and zero the ``delta`` buffer of every parameter that requires grad."""
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.requires_grad:
                    state = self.demo_state.get(p, {})

                    state['delta'] = torch.zeros_like(p)

    def demo_all_gather(self, sparse_idx, sparse_val):
        r"""All-gather the compressed indices and values from every rank of the process group.

        :param sparse_idx: indices tensor produced by the compressor.
        :param sparse_val: values tensor produced by the compressor.
        :return: a pair of lists (indices per rank, values per rank).
        """
        world_size: int = get_world_size() if self.process_group is None else self.process_group.size()

        sparse_idx_list = [torch.zeros_like(sparse_idx) for _ in range(world_size)]
        sparse_val_list = [torch.zeros_like(sparse_val) for _ in range(world_size)]

        # both collectives are launched before either is waited on.
        sparse_idx_handle = all_gather(sparse_idx_list, sparse_idx, group=self.process_group, async_op=True)
        sparse_val_handle = all_gather(sparse_val_list, sparse_val, group=self.process_group, async_op=True)

        sparse_idx_handle.wait()
        sparse_val_handle.wait()

        return sparse_idx_list, sparse_val_list

    @torch.no_grad()
    def init_group(self):
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Replace each gradient by the sign of the gathered, decoded update, then run the SGD step.

        :param closure: CLOSURE. optional closure, forwarded unchanged to ``torch.optim.SGD.step``.
        :return: LOSS. whatever ``torch.optim.SGD.step`` returns.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        self.data_transmit = 0
        self.data_receive = 0

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            lr = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                state = self.demo_state.get(p, {})

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=lr,
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                if self.compression_decay != 1:
                    state['delta'].mul_(self.compression_decay)

                state['delta'].add_(grad, alpha=lr)

                sparse_idx, sparse_val, x_shape = self.compress.compress(
                    self.transform.encode(state['delta']), self.compression_top_k
                )

                # what this rank is about to transmit is removed from its local buffer.
                transmit_grad = self.transform.decode(self.compress.decompress(p, sparse_idx, sparse_val, x_shape))

                state['delta'].sub_(transmit_grad)

                sparse_idx_gather, sparse_val_gather = self.demo_all_gather(sparse_idx, sparse_val)

                self.data_transmit += sparse_idx.nbytes + sparse_val.nbytes
                for si, v in zip(sparse_idx_gather, sparse_val_gather):
                    self.data_receive += si.nbytes + v.nbytes

                new_grad = self.transform.decode(
                    self.compress.batch_decompress(p, sparse_idx_gather, sparse_val_gather, x_shape)
                )

                # `p.grad` cannot be None here (guarded by the `continue` above), so it is overwritten in place.
                p.grad.copy_(new_grad)

                p.grad.sign_()

        return super().step(closure)

find_dtype()

Return dtype of the parameter.

Source code in pytorch_optimizer/optimizer/demo.py
348
349
350
351
352
353
354
def find_dtype(self) -> torch.dtype:
    r"""Return the dtype of the first parameter that requires grad, falling back to ``torch.float32``."""
    trainable = (
        candidate for param_group in self.param_groups for candidate in param_group['params'] if candidate.requires_grad
    )
    first_trainable = next(trainable, None)
    if first_trainable is None:
        return torch.float32
    return first_trainable.dtype

DiffGrad

Bases: BaseOptimizer

An Optimization Method for Convolutional Neural Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
rectify bool

bool. perform the rectified update similar to RAdam.

False
n_sma_threshold int

int. SMA threshold for the rectified update; the rectified step is taken once the SMA length reaches it (recommended is 5).

5
degenerated_to_sgd bool

bool. degenerated to SGD.

True
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/diffgrad.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
class DiffGrad(BaseOptimizer):
    r"""An Optimization Method for Convolutional Neural Networks.

    The first moment is scaled element-wise by the diffGrad friction coefficient
    ``sigmoid(|previous_grad - grad|)`` before it is applied to the parameters.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. decay rates of the running averages of the gradient and of the squared gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param n_sma_threshold: int. SMA threshold for the rectified update (recommended is 5).
    :param degenerated_to_sgd: bool. degenerated to SGD.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize

        # extra keyword arguments become group options; `step` and `init_group` read 'adanorm', 'adanorm_r'
        # and 'adam_debias' from the group when present.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'diffGrad'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['previous_grad'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, previous_grad = state['exp_avg'], state['exp_avg_sq'], state['previous_grad']

                # presumably returns real-valued views for complex tensors, so the math below runs on real
                # tensors; confirm against the helper.
                p, grad, exp_avg, exp_avg_sq, previous_grad = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, previous_grad
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                # diffGrad friction coefficient applied to the first moment:
                # dfc = sigmoid(|previous_grad - grad|) * exp_avg.
                dfc = previous_grad.clone()
                dfc.sub_(grad).abs_().sigmoid_().mul_(exp_avg)
                state['previous_grad'].copy_(
                    torch.view_as_complex(grad) if torch.is_complex(state['previous_grad']) else grad
                )

                # weight decay is applied on every path; it used to be skipped whenever `rectify` was False.
                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if not group['rectify'] or n_sma >= self.n_sma_threshold:
                    # the friction-scaled moment `dfc` is used here; the non-rectified path used the plain
                    # `exp_avg`, which discarded the diffGrad coefficient and reduced the update to Adam.
                    p.addcdiv_(dfc, de_nom, value=-step_size)
                elif step_size > 0:
                    # rectified update below the SMA threshold: un-normalized step on the plain first moment.
                    p.add_(exp_avg, alpha=-step_size)

        return loss

EXAdam

Bases: BaseOptimizer

The Power of Adaptive Cross-Moments.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/exadam.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class EXAdam(BaseOptimizer):
    r"""The Power of Adaptive Cross-Moments.

    Adam-style optimizer: it keeps running averages of the gradient (``exp_avg``) and of its element-wise square
    (``exp_avg_sq``), rescales each debiased moment with a factor computed from the *other* moment, and adds the
    (rescaled) current gradient to the first moment before dividing by the square root of the second moment.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # sqrt(2), cached once; used by the step-dependent learning-rate factor in `step`.
        self.sq2: float = np.sqrt(2)

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'EXAdam'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of ``group`` that currently has a gradient.

        Parameters without a gradient are skipped, and parameters whose state already exists are left untouched,
        so calling this repeatedly is safe.

        :param group: GROUP. parameter group whose parameters should be initialized.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Run the (idempotent) state initialization on every step rather than only on the first one.
            # A parameter that had no gradient on the group's first step would otherwise never get its
            # 'exp_avg' / 'exp_avg_sq' buffers and fail with a KeyError once it does receive a gradient.
            self.init_group(group)
            group['step'] = group.get('step', 0) + 1

            beta1, beta2 = group['betas']

            # NOTE(review): `debias` presumably returns `1 - beta ** step` - confirm against BaseOptimizer.
            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])

            # lr * log(sqrt(2 * (step + 1))): the effective step size grows slowly with the step count.
            step_size: float = group['lr'] * np.log(np.sqrt(group['step'] + 1) * self.sq2)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # standard Adam moment updates
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # cross-moment factors: d1 (built from the second moment) rescales the first moment and the raw
                # gradient, d2 (built from the squared first moment) rescales the second moment.
                d1 = 1.0 + exp_avg_sq.div(exp_avg_sq.add(group['eps'])) * (1.0 - bias_correction2)

                exp_avg_p2 = exp_avg.pow(2)
                d2 = 1.0 + exp_avg_p2.div(exp_avg_p2.add(group['eps'])) * (1.0 - bias_correction1)

                m_tilde = exp_avg.div(bias_correction1) * d1
                v_tilde = exp_avg_sq.div(bias_correction2) * d2

                g_tilde = grad.div(bias_correction1) * d1

                update = (m_tilde + g_tilde) / v_tilde.sqrt().add_(group['eps'])

                p.add_(update, alpha=-step_size)

        return loss

DynamicLossScaler

Dynamically adjusts the loss scaling factor.

Dynamic loss scalers are important in mixed-precision training.
They help us avoid underflows and overflows in low-precision gradients.

See here for information:
<https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling>

Shamelessly stolen and adapted from FairSeq.
<https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py>

Reference : 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'

Parameters:

Name Type Description Default
init_scale float

Initial loss scale.

2.0 ** 15
scale_factor float

Factor by which to increase or decrease loss scale.

2.0
scale_window int

If we do not experience overflow in scale_window iterations, loss scale will increase by scale_factor.

2000
tolerance float

Pct of iterations that have overflowed after which we must decrease the loss scale.

0.0
threshold Optional[float]

If not None, loss scale will not decrease below this threshold.

None
Source code in pytorch_optimizer/optimizer/fp16.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class DynamicLossScaler:
    r"""Keep a loss-scaling factor and adapt it to observed gradient overflows.

        The scale is divided by `scale_factor` when overflows become too frequent and multiplied by `scale_factor`
        after a run of overflow-free iterations. Loss scaling is used in mixed-precision training to avoid
        underflows and overflows in low-precision gradients.

        See here for information:
        <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling>

        Adapted from FairSeq.
        <https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py>

        Reference : 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'

    :param init_scale: Initial loss scale.
    :param scale_factor: Factor by which to increase or decrease loss scale.
    :param scale_window: If we do not experience overflow in scale_window iterations,
        loss scale will increase by scale_factor.
    :param tolerance: Pct of iterations that have overflowed after which we must decrease the loss scale.
    :param threshold: If not None, loss scale will not decrease below this threshold.
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 15,
        scale_factor: float = 2.0,
        scale_window: int = 2000,
        tolerance: float = 0.00,
        threshold: Optional[float] = None,
    ):  # fmt: skip
        # configuration
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold

        # bookkeeping; -1 means "has not happened yet"
        self.iter: int = 0
        self.last_overflow_iter: int = -1
        self.last_rescale_iter: int = -1
        self.overflows_since_rescale: int = 0
        self.has_overflow_serial: bool = False

    def update_scale(self, overflow: bool):
        r"""Record the outcome of one iteration and adjust the loss scale if needed.

            On overflow, the scale is decreased once the fraction of overflowing iterations since the last rescale
            reaches `tolerance`. Without overflow, the scale is increased whenever the number of iterations since
            the last overflow is a multiple of `scale_window`.

        :param overflow: bool. whether the current iteration overflowed.
        """
        if overflow:
            elapsed: int = self.iter - self.last_rescale_iter

            self.last_overflow_iter = self.iter
            self.overflows_since_rescale += 1

            # fraction of iterations since the last rescale that overflowed
            overflow_ratio: float = self.overflows_since_rescale / float(elapsed)
            if overflow_ratio >= self.tolerance:
                self.decrease_loss_scale()

                # start a fresh counting period from this iteration
                self.last_rescale_iter = self.iter
                self.overflows_since_rescale = 0
        elif (self.iter - self.last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self.last_rescale_iter = self.iter

        self.iter += 1

    def decrease_loss_scale(self):
        r"""Divide the loss scale by self.scale_factor.

        NOTE: when `self.threshold` is set, the loss_scale is clamped so it never goes below it.
        """
        self.loss_scale /= self.scale_factor
        if self.threshold is None:
            return
        self.loss_scale = max(self.loss_scale, self.threshold)

decrease_loss_scale()

Decrease the loss scale by self.scale_factor.

NOTE: the loss_scale will not go below self.threshold.

Source code in pytorch_optimizer/optimizer/fp16.py
83
84
85
86
87
88
89
90
def decrease_loss_scale(self):
    r"""Divide the loss scale by self.scale_factor.

    NOTE: when `self.threshold` is set, the loss_scale is clamped so it never goes below it.
    """
    self.loss_scale /= self.scale_factor
    if self.threshold is None:
        return
    self.loss_scale = max(self.loss_scale, self.threshold)

update_scale(overflow)

Update the loss scale.

If overflow exceeds our tolerance, we decrease the loss scale.
If the number of iterations since the last overflow exceeds the scale window, we increase the loss scale.

Parameters:

Name Type Description Default
overflow bool

bool. adjust scales to prevent overflow.

required
Source code in pytorch_optimizer/optimizer/fp16.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def update_scale(self, overflow: bool):
    r"""Record the outcome of one iteration and adjust the loss scale if needed.

        On overflow, the scale is decreased once the fraction of overflowing iterations since the last rescale
        reaches `tolerance`. Without overflow, the scale is increased whenever the number of iterations since
        the last overflow is a multiple of `scale_window`.

    :param overflow: bool. whether the current iteration overflowed.
    """
    if overflow:
        elapsed: int = self.iter - self.last_rescale_iter

        self.last_overflow_iter = self.iter
        self.overflows_since_rescale += 1

        # fraction of iterations since the last rescale that overflowed
        overflow_ratio: float = self.overflows_since_rescale / float(elapsed)
        if overflow_ratio >= self.tolerance:
            self.decrease_loss_scale()

            # start a fresh counting period from this iteration
            self.last_rescale_iter = self.iter
            self.overflows_since_rescale = 0
    elif (self.iter - self.last_overflow_iter) % self.scale_window == 0:
        self.loss_scale *= self.scale_factor
        self.last_rescale_iter = self.iter

    self.iter += 1

SafeFP16Optimizer

Bases: Optimizer

Safe FP16 Optimizer.

Parameters:

Name Type Description Default
optimizer Optimizer

OPTIMIZER.

required
aggregate_g_norms bool

bool. aggregate_g_norms.

False
min_loss_scale float

float. min_loss_scale.

2 ** -5
Source code in pytorch_optimizer/optimizer/fp16.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
class SafeFP16Optimizer(Optimizer):  # pragma: no cover
    r"""Wrap an optimizer so that it updates float32 copies of the original parameters with dynamic loss scaling.

    The wrapped optimizer's single parameter group is re-pointed at float32 copies (`fp32_params`) of its original
    parameters (`fp16_params`). Gradients are copied and un-scaled into the float32 copies before each update, and
    the updated values are copied back into the original parameters afterwards.

    :param optimizer: OPTIMIZER. optimizer to wrap; it must have exactly one parameter group.
    :param aggregate_g_norms: bool. forwarded as `sync` to `clip_grad_norm` in `clip_main_grads`.
    :param min_loss_scale: float. an overflow that leaves the loss scale at or below this value raises
        FloatingPointError.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        aggregate_g_norms: bool = False,
        min_loss_scale: float = 2 ** -5,
    ):  # fmt: skip
        self.optimizer = optimizer
        self.aggregate_g_norms = aggregate_g_norms
        self.min_loss_scale = min_loss_scale

        self.fp16_params = self.get_parameters(optimizer)
        self.fp32_params = self.build_fp32_params(self.fp16_params, flatten=False)

        # Only a single parameter group is supported: supporting more would mean walking the groups and handing
        # each one its slice of `fp32_params`.
        if len(optimizer.param_groups) != 1:
            raise NotImplementedError('Need to implement the parameter group transfer.')

        # from here on the wrapped optimizer updates the float32 copies
        optimizer.param_groups[0]['params'] = self.fp32_params

        self.scaler: DynamicLossScaler = DynamicLossScaler(2.0 ** 15)  # fmt: skip
        self.needs_sync: bool = True

    @classmethod
    def get_parameters(cls, optimizer: Optimizer) -> List:
        r"""Return all parameters of all parameter groups of `optimizer` as one flat list."""
        return [param for group in optimizer.param_groups for param in group['params']]

    @classmethod
    def build_fp32_params(
        cls, parameters: PARAMETERS, flatten: bool = True
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        r"""Create float32 copies of `parameters`, each with an allocated gradient.

        :param parameters: PARAMETERS. parameters to copy.
        :param flatten: bool. if True, return one 1-D parameter holding all values back to back; otherwise return
            a list with one float32 parameter per input parameter.
        """
        if not flatten:
            masters: List = []
            for param in parameters:
                master = nn.Parameter(param.float())
                master.grad = torch.zeros_like(master)
                masters.append(master)

            return masters

        numel_total: int = sum(param.numel() for param in parameters)
        flat = torch.zeros(numel_total, dtype=torch.float, device=parameters[0].device)

        start: int = 0
        for param in parameters:
            stop = start + param.numel()
            flat[start:stop].copy_(param.view(-1))  # fmt: skip
            start = stop

        flat = nn.Parameter(flat)
        flat.grad = flat.new(numel_total)

        return flat

    def state_dict(self) -> Dict:
        r"""Return the wrapped optimizer's state dict, plus the current loss scale under 'loss_scaler'."""
        payload = self.optimizer.state_dict()
        if self.scaler is None:
            return payload

        payload['loss_scaler'] = self.scaler.loss_scale
        return payload

    def load_state_dict(self, state_dict: Dict):
        r"""Restore the loss scale (when a float one is stored) and load the wrapped optimizer's state.

            In general, we should prefer the configuration of the existing optimizer instance (e.g., learning rate)
            over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of
            optimizer args.

        :param state_dict: Dict. state_dict.
        """
        if self.scaler is not None and 'loss_scaler' in state_dict:
            saved_scale = state_dict['loss_scaler']
            if isinstance(saved_scale, float):
                self.scaler.loss_scale = saved_scale

        self.optimizer.load_state_dict(state_dict)

    def backward(self, loss, update_main_grads: bool = False):
        r"""Back-propagate `loss` after multiplying it by the current loss scale.

            Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
            additionally dynamically scales the loss to avoid gradient underflow.

        :param loss: float. loss.
        :param update_main_grads: bool. update main gradient.
        """
        scaled_loss = loss if self.scaler is None else loss * self.scaler.loss_scale
        scaled_loss.backward()

        # fresh gradients exist on the original parameters, so the float32 copies are out of date
        self.needs_sync = True
        if not update_main_grads:
            return

        self.update_main_grads()

    def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0) -> None:
        r"""Copy pending gradients into the float32 copies, dividing out the loss scale.

        Does nothing unless `needs_sync` is set.

        :param multiply_grads: float. extra factor applied to the copied gradients.
        """
        if not self.needs_sync:
            return

        if self.scaler is not None:
            multiply_grads /= self.scaler.loss_scale

        for half, master in zip(self.fp16_params, self.fp32_params):
            if not half.requires_grad:
                continue

            if half.grad is None:
                master.grad = torch.zeros_like(half, dtype=torch.float)
                continue

            master.grad.copy_(half.grad)
            master.grad.mul_(multiply_grads)

        self.needs_sync = False

    def multiply_grads(self, c: float) -> None:
        r"""Multiply grads by a constant c."""
        if not self.needs_sync:
            for master in self.fp32_params:
                master.grad.mul_(c)
            return

        # gradients have not been copied yet: fold `c` into the pending sync
        self.sync_fp16_grads_to_fp32(c)

    def update_main_grads(self) -> None:
        r"""Bring the float32 gradients up to date."""
        self.sync_fp16_grads_to_fp32()

    def clip_main_grads(self, max_norm: float):
        r"""Clip the float32 gradient norm and update the dynamic loss scaler.

        :param max_norm: float. maximum gradient norm, forwarded to `clip_grad_norm`.
        :return: the value returned by `clip_grad_norm`.
        :raises FloatingPointError: on overflow when the updated loss scale is at or below `min_loss_scale`.
        """
        self.sync_fp16_grads_to_fp32()

        grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_g_norms)

        if self.scaler is None:
            return grad_norm

        # detect overflow and adjust loss scale
        overflow: bool = has_overflow(grad_norm)
        prev_scale: float = self.scaler.loss_scale

        self.scaler.update_scale(overflow)

        if not overflow:
            return grad_norm

        self.zero_grad()
        if self.scaler.loss_scale <= self.min_loss_scale:
            # FloatingPointError is an uncommon error that parent functions can safely catch to stop training.
            # The scale is put back to its previous value before raising.
            self.scaler.loss_scale = prev_scale

            raise FloatingPointError(
                f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
                'Try lowering the learning rate, using gradient clipping or increasing the batch size.\n'
                f'Overflow: setting loss scale to {self.scaler.loss_scale}'
            )

        return grad_norm

    def step(self, closure: CLOSURE = None):
        r"""Run one step of the wrapped optimizer, then copy the results back to the original parameters."""
        self.sync_fp16_grads_to_fp32()
        self.optimizer.step(closure)

        for half, master in zip(self.fp16_params, self.fp32_params):
            if half.requires_grad:
                half.data.copy_(master)

    def zero_grad(self) -> None:
        r"""Drop the original parameters' gradients and zero the float32 gradients in place."""
        for half in self.fp16_params:
            half.grad = None
        for master in self.fp32_params:
            master.grad.zero_()
        self.needs_sync = False

    def get_lr(self) -> float:
        r"""Return the learning rate reported by the wrapped optimizer."""
        return self.optimizer.get_lr()

    def set_lr(self, lr: float):
        r"""Forward the learning rate to the wrapped optimizer."""
        self.optimizer.set_lr(lr)

    @property
    def loss_scale(self) -> float:
        r"""Convenience function which TorchAgent calls to get current scale value."""
        return self.scaler.loss_scale

loss_scale: float property

Convenience function which TorchAgent calls to get current scale value.

backward(loss, update_main_grads=False)

Compute the sum of gradients of the given tensor w.r.t. graph leaves.

Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
additionally dynamically scales the loss to avoid gradient underflow.

Parameters:

Name Type Description Default
loss

float. loss.

required
update_main_grads bool

bool. update main gradient.

False
Source code in pytorch_optimizer/optimizer/fp16.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def backward(self, loss, update_main_grads: bool = False):
    r"""Back-propagate `loss` after multiplying it by the current loss scale.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
        additionally dynamically scales the loss to avoid gradient underflow.

    :param loss: float. loss.
    :param update_main_grads: bool. update main gradient.
    """
    scaled_loss = loss if self.scaler is None else loss * self.scaler.loss_scale
    scaled_loss.backward()

    # fresh gradients exist now, so a sync is required before they are used
    self.needs_sync = True
    if not update_main_grads:
        return

    self.update_main_grads()

clip_main_grads(max_norm)

Clip gradient norm and updates dynamic loss scaler.

Source code in pytorch_optimizer/optimizer/fp16.py
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
def clip_main_grads(self, max_norm: float):
    r"""Clip the float32 gradient norm and update the dynamic loss scaler.

    :param max_norm: float. maximum gradient norm, forwarded to `clip_grad_norm`.
    :return: the value returned by `clip_grad_norm`.
    :raises FloatingPointError: on overflow when the updated loss scale is at or below `min_loss_scale`.
    """
    self.sync_fp16_grads_to_fp32()

    grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_g_norms)

    if self.scaler is None:
        return grad_norm

    # detect overflow and adjust loss scale
    overflow: bool = has_overflow(grad_norm)
    prev_scale: float = self.scaler.loss_scale

    self.scaler.update_scale(overflow)

    if not overflow:
        return grad_norm

    self.zero_grad()
    if self.scaler.loss_scale <= self.min_loss_scale:
        # FloatingPointError is an uncommon error that parent functions can safely catch to stop training.
        # The scale is put back to its previous value before raising.
        self.scaler.loss_scale = prev_scale

        raise FloatingPointError(
            f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
            'Try lowering the learning rate, using gradient clipping or increasing the batch size.\n'
            f'Overflow: setting loss scale to {self.scaler.loss_scale}'
        )

    return grad_norm

get_lr()

Get learning rate.

Source code in pytorch_optimizer/optimizer/fp16.py
273
274
275
def get_lr(self) -> float:
    r"""Return the learning rate reported by the wrapped optimizer."""
    wrapped = self.optimizer
    return wrapped.get_lr()

load_state_dict(state_dict)

Load an optimizer state dict.

In general, we should prefer the configuration of the existing optimizer instance (e.g., learning rate)
over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of
optimizer args.

Parameters:

Name Type Description Default
state_dict Dict

Dict. state_dict.

required
Source code in pytorch_optimizer/optimizer/fp16.py
166
167
168
169
170
171
172
173
174
175
176
177
def load_state_dict(self, state_dict: Dict):
    r"""Restore the loss scale (when a float one is stored) and load the wrapped optimizer's state.

        In general, we should prefer the configuration of the existing optimizer instance (e.g., learning rate)
        over that found in the state_dict. This allows us to resume training from a checkpoint using a new set of
        optimizer args.

    :param state_dict: Dict. state_dict.
    """
    if self.scaler is not None and 'loss_scaler' in state_dict:
        saved_scale = state_dict['loss_scaler']
        # only a float value is accepted; anything else leaves the current scale untouched
        if isinstance(saved_scale, float):
            self.scaler.loss_scale = saved_scale

    self.optimizer.load_state_dict(state_dict)

multiply_grads(c)

Multiply grads by a constant c.

Source code in pytorch_optimizer/optimizer/fp16.py
215
216
217
218
219
220
221
222
def multiply_grads(self, c: float) -> None:
    r"""Multiply grads by a constant c."""
    if not self.needs_sync:
        for master in self.fp32_params:
            master.grad.mul_(c)
        return

    # gradients have not been copied yet: fold `c` into the pending sync
    self.sync_fp16_grads_to_fp32(c)

set_lr(lr)

Set learning rate.

Source code in pytorch_optimizer/optimizer/fp16.py
277
278
279
def set_lr(self, lr: float):
    r"""Forward the learning rate to the wrapped optimizer."""
    wrapped = self.optimizer
    wrapped.set_lr(lr)

state_dict()

Return the optimizer state dict.

Source code in pytorch_optimizer/optimizer/fp16.py
159
160
161
162
163
164
def state_dict(self) -> Dict:
    r"""Return the wrapped optimizer's state dict, plus the current loss scale under 'loss_scaler'."""
    payload = self.optimizer.state_dict()
    if self.scaler is None:
        return payload

    payload['loss_scaler'] = self.scaler.loss_scale
    return payload

step(closure=None)

Perform a single optimization step.

Source code in pytorch_optimizer/optimizer/fp16.py
255
256
257
258
259
260
261
262
263
def step(self, closure: CLOSURE = None):
    r"""Run one step of the wrapped optimizer, then copy the results back to the original parameters."""
    self.sync_fp16_grads_to_fp32()
    self.optimizer.step(closure)

    # write the updated float32 values back; parameters that do not require grad are left untouched
    for half, master in zip(self.fp16_params, self.fp32_params):
        if half.requires_grad:
            half.data.copy_(master)

sync_fp16_grads_to_fp32(multiply_grads=1.0)

Sync fp16 to fp32 gradients.

Source code in pytorch_optimizer/optimizer/fp16.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0) -> None:
    r"""Copy pending gradients into the float32 copies, dividing out the loss scale.

    Does nothing unless `needs_sync` is set.

    :param multiply_grads: float. extra factor applied to the copied gradients.
    """
    if not self.needs_sync:
        return

    if self.scaler is not None:
        multiply_grads /= self.scaler.loss_scale

    for half, master in zip(self.fp16_params, self.fp32_params):
        if not half.requires_grad:
            continue

        if half.grad is None:
            # no gradient was produced for this parameter: give the copy a zero gradient
            master.grad = torch.zeros_like(half, dtype=torch.float)
            continue

        master.grad.copy_(half.grad)
        master.grad.mul_(multiply_grads)

    self.needs_sync = False

zero_grad()

Clear the gradients of all optimized parameters.

Source code in pytorch_optimizer/optimizer/fp16.py
265
266
267
268
269
270
271
def zero_grad(self) -> None:
    r"""Drop the original parameters' gradients and zero the float32 gradients in place."""
    for half in self.fp16_params:
        half.grad = None

    for master in self.fp32_params:
        master.grad.zero_()

    # nothing is left to copy until the next backward pass
    self.needs_sync = False

FAdam

Bases: BaseOptimizer

Adam is a natural gradient optimizer using diagonal empirical Fisher information.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the (natural) gradient and of the squared gradient.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.1
clip float

float. maximum norm of the gradient.

1.0
p float

float. exponent applied to the running average of squared gradients (fim) before it divides the gradient.

0.5
eps float

float. term added to the denominator to improve numerical stability.

1e-08
momentum_dtype dtype

torch.dtype. dtype of momentum.

float32
fim_dtype dtype

torch.dtype. dtype of fim.

float32
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/fadam.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class FAdam(BaseOptimizer):
    r"""Adam is a natural gradient optimizer using diagonal empirical Fisher information.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the (natural) gradient and of the
        squared gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param clip: float. upper bound on the RMS of the preconditioned gradient and of the preconditioned weights.
    :param p: float. exponent applied to the running average of squared gradients (`fim`) before it divides the
        gradient.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param momentum_dtype: torch.dtype. dtype of momentum.
    :param fim_dtype: torch.dtype. dtype of fim.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.1,
        clip: float = 1.0,
        p: float = 0.5,
        eps: float = 1e-8,
        momentum_dtype: torch.dtype = torch.float32,
        fim_dtype: torch.dtype = torch.float32,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(clip, 'clip')
        self.validate_positive(p, 'p')
        self.validate_non_negative(eps, 'eps')

        self.momentum_dtype = momentum_dtype
        self.fim_dtype = fim_dtype
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'clip': clip,
            'p': p,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'FAdam'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of ``group`` that currently has a gradient.

        Parameters without a gradient are skipped, and parameters whose state already exists are left untouched,
        so calling this repeatedly is safe.

        :param group: GROUP. parameter group whose parameters should be initialized.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['momentum'] = torch.zeros_like(p, dtype=self.momentum_dtype)
                state['fim'] = torch.zeros_like(p, dtype=self.fim_dtype)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Run the (idempotent) state initialization on every step rather than only on the first one.
            # A parameter that had no gradient on the group's first step would otherwise never get its
            # 'momentum' / 'fim' buffers and fail with a KeyError once it does receive a gradient.
            self.init_group(group)
            group['step'] = group.get('step', 0) + 1

            beta1, beta2 = group['betas']

            # NOTE(review): `debias_beta` presumably returns a bias-corrected beta2 for this step - confirm
            # against BaseOptimizer.
            curr_beta2: float = self.debias_beta(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                momentum, fim = state['momentum'], state['fim']

                # running average of the squared gradient
                fim.mul_(curr_beta2).addcmul_(grad, grad, value=1.0 - curr_beta2)

                # eps is scaled down when the gradient RMS is below 1
                rms_grad = grad.pow(2).mean().sqrt_()
                curr_eps = min(rms_grad, 1) * group['eps']

                # precondition the gradient with fim ** p
                fim_base = fim.pow(group['p']).add_(curr_eps)
                grad_nat = grad / fim_base

                # rescale so that the RMS of the preconditioned gradient does not exceed `clip`
                rms = grad_nat.pow(2).mean().sqrt_()
                divisor = max(1, rms) / group['clip']
                grad_nat.div_(divisor)

                momentum.mul_(beta1).add_(grad_nat, alpha=1.0 - beta1)

                # weight decay term: the weights get the same preconditioning and RMS clipping as the gradient
                grad_weights = p / fim_base

                rms = torch.pow(grad_weights, 2).mean().sqrt_()
                divisor = max(1, rms) / group['clip']
                grad_weights.div_(divisor)

                # final update direction = weight_decay * clipped preconditioned weights + momentum
                grad_weights.mul_(group['weight_decay']).add_(momentum)

                p.add_(grad_weights, alpha=-group['lr'])

        return loss

Fira

Bases: BaseOptimizer

Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint? Fira with AdamW optimizer.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/fira.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class Fira(BaseOptimizer):
    r"""Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint? Fira with AdamW optimizer.

    Two-dimensional parameters that belong to a parameter group defining ``rank`` are routed through a
    ``GaLoreProjector``; such a group must also define ``update_proj_gap``, ``scale`` and ``projection_type``
    (they are forwarded through ``**kwargs`` into the group defaults). All other parameters receive an Adam
    update. Weight decay is applied through ``apply_weight_decay`` with ``weight_decouple=True``.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and the squared gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # extra keyword arguments (e.g. the projection settings) become per-group defaults.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Fira'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Validate the parameters of a group; no optimizer state is allocated here.

        :param group: GROUP. parameter group to validate.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                use_projection: bool = 'rank' in group and p.dim() == 2

                # keep a handle on the unprojected gradient; `grad` is rebound to the projector output below.
                full_grad = grad

                if use_projection:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            rank=group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            projection_type=group['projection_type'],
                        )

                    grad = state['projector'].project(grad, group['step'])

                # the moments are allocated lazily so they take the shape of the (possibly projected) gradient.
                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                norm_grad = exp_avg / de_nom

                if use_projection:
                    sub_grad = state['projector'].project_back(grad)

                    norm_dim: int = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1

                    # ratio between the Adam-normalized and the raw projected gradient norms along `norm_dim`.
                    scaling_factor = torch.norm(norm_grad, dim=norm_dim) / (torch.norm(grad, dim=norm_dim) + 1e-8)
                    if norm_dim == 1:
                        scaling_factor = scaling_factor.unsqueeze(1)

                    # residual = original gradient minus its back-projected part. It must use the unprojected
                    # gradient: `grad` now holds the projector output.
                    # NOTE(review): assumes `project_back` returns a tensor shaped like the original gradient.
                    scaling_grad = full_grad.sub(sub_grad).mul_(scaling_factor)

                    if 'scaling_grad' in state:
                        scaling_grad_norm = torch.norm(scaling_grad)

                        # limit the growth of the residual norm to a factor of 1.01 per step.
                        limiter = max(scaling_grad_norm / (state['scaling_grad'] + 1e-8), 1.01) / 1.01
                        scaling_grad.div_(limiter)

                        state['scaling_grad'] = scaling_grad_norm / limiter
                    else:
                        state['scaling_grad'] = torch.norm(scaling_grad)

                    norm_grad = state['projector'].project_back(norm_grad).add_(scaling_grad)

                p.add_(norm_grad, alpha=-step_size)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

        return loss

FOCUS

Bases: BaseOptimizer

First Order Concentrated Updating Scheme.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and of the parameters.

(0.9, 0.999)
gamma float

float. control the strength of the attraction.

0.1
weight_decay float

float. weight decay (L2 penalty).

0.0
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/focus.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class FOCUS(BaseOptimizer):
    r"""First Order Concentrated Updating Scheme.

    Keeps a running average of the gradient and a running average of the parameters; each step moves the
    parameters along the sign of the gradient average plus ``gamma`` times the sign of their offset from the
    (bias-corrected) parameter average.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the parameters.
    :param gamma: float. control the strength of the attraction.
    :param weight_decay: float. weight decay (L2 penalty).
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        betas: BETAS = (0.9, 0.999),
        gamma: float = 0.1,
        weight_decay: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(gamma, 'gamma', 0.0, 1.0, '[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: DEFAULTS = dict(lr=lr, betas=betas, gamma=gamma, weight_decay=weight_decay)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'FOCUS'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state buffers for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                continue

            param_state['exp_avg'] = torch.zeros_like(param)
            param_state['pbar'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            weight_decay: float = group['weight_decay']
            lr = group['lr']
            gamma = group['gamma']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                weight, grad, grad_avg, weight_avg = self.view_as_real(
                    param, grad, param_state['exp_avg'], param_state['pbar']
                )

                # running averages of the gradient and of the parameters themselves.
                grad_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                weight_avg.mul_(beta2).add_(weight, alpha=1.0 - beta2)

                weight_avg_hat = weight_avg / bias_correction2

                if weight_decay > 0.0:
                    weight.add_(weight_avg_hat, alpha=-lr * weight_decay)

                # sign of the offset from the averaged parameters (scaled by gamma) plus sign of the gradient average.
                direction = (weight - weight_avg_hat).sign_()
                direction.mul_(gamma)
                direction.add_(torch.sign(grad_avg))

                weight.add_(direction, alpha=-lr)

        return loss

Fromage

Bases: BaseOptimizer

On the distance between two neural networks and the stability of learning.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
p_bound Optional[float]

Optional[float]. Restricts the optimisation to a bounded set. A value of 2.0 restricts parameter norms to lie within 2x their initial norms. This regularises the model class.

None
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/fromage.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class Fromage(BaseOptimizer):
    r"""On the distance between two neural networks and the stability of learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param p_bound: Optional[float]. Restricts the optimisation to a bounded set. A value of 2.0 restricts parameter
        norms to lie within 2x their initial norms. This regularises the model class.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self, params: PARAMETERS, lr: float = 1e-2, p_bound: Optional[float] = None, maximize: bool = False, **kwargs
    ):
        self.validate_learning_rate(lr)

        self.p_bound = p_bound
        self.maximize = maximize

        defaults: DEFAULTS = dict(lr=lr)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Fromage'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Record the maximum allowed norm of each parameter when ``p_bound`` is set.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]

            if self.p_bound is not None and len(param_state) == 0:
                # upper bound on the norm: the norm at initialization time scaled by `p_bound`.
                param_state['max'] = param.norm().mul_(self.p_bound)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            lr = group['lr']
            pre_factor: float = math.sqrt(1 + lr ** 2)

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                weight, grad = self.view_as_real(param, grad)

                weight_norm, grad_norm = weight.norm(), grad.norm()

                # rescale the gradient to the norm of the weights unless either norm is zero.
                both_nonzero = weight_norm > 0.0 and grad_norm > 0.0
                direction = grad * (weight_norm / grad_norm) if both_nonzero else grad
                weight.add_(direction, alpha=-lr)

                weight.div_(pre_factor)

                if self.p_bound is None:
                    continue

                # project back onto the ball of radius `max` when the norm exceeds it.
                weight_norm = weight.norm()
                max_norm = param_state['max']
                if weight_norm > max_norm:
                    weight.mul_(max_norm).div_(weight_norm)

        return loss

FTRL

Bases: BaseOptimizer

Follow The Regularized Leader.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
lr_power float

float. controls how the learning rate decreases during training. use zero for a fixed learning rate.

-0.5
beta float

float. beta value in the paper.

0.0
lambda_1 float

float. L1 regularization parameter.

0.0
lambda_2 float

float. L2 regularization parameter.

0.0
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ftrl.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class FTRL(BaseOptimizer):
    r"""Follow The Regularized Leader.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param lr_power: float. controls how the learning rate decreases during training. use zero for a fixed learning
        rate.
    :param beta: float. beta value in the paper.
    :param lambda_1: float. L1 regularization parameter.
    :param lambda_2: float. L2 regularization parameter.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        lr_power: float = -0.5,
        beta: float = 0.0,
        lambda_1: float = 0.0,
        lambda_2: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(beta, 'beta')
        self.validate_non_positive(lr_power, 'lr_power')
        self.validate_non_negative(lambda_1, 'lambda_1')
        self.validate_non_negative(lambda_2, 'lambda_2')

        self.maximize = maximize

        defaults: DEFAULTS = dict(lr=lr, lr_power=lr_power, beta=beta, lambda_1=lambda_1, lambda_2=lambda_2)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'FTRL'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the ``z`` and ``n`` accumulators for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                continue

            param_state['z'] = torch.zeros_like(param)
            param_state['n'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            lr = group['lr']
            exponent = -group['lr_power']
            beta = group['beta']
            lambda_1, lambda_2 = group['lambda_1'], group['lambda_2']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                weight, grad, z, n = self.view_as_real(param, grad, param_state['z'], param_state['n'])

                grad_sq = grad.pow(2)

                # sigma = ((n + g^2)^(-lr_power) - n^(-lr_power)) / lr, using `n` before it is updated.
                sigma = (n + grad_sq).pow_(exponent)
                sigma.sub_(n.pow(exponent))
                sigma.div_(lr)

                z.add_(grad)
                z.sub_(sigma.mul(weight))
                n.add_(grad_sq)

                numerator = z.sign().mul_(lambda_1).sub_(z)
                denominator = (beta + n.sqrt()).div_(lr).add_(lambda_2)
                numerator.div_(denominator)

                # overwrite the weights with the closed-form solution, zeroing entries where |z| < lambda_1.
                weight.copy_(numerator)
                weight.masked_fill_(z.abs() < lambda_1, 0.0)

        return loss

centralize_gradient(grad, gc_conv_only=False)

Gradient Centralization (GC).

Parameters:

Name Type Description Default
grad Tensor

torch.Tensor. gradient.

required
gc_conv_only bool

bool. 'False' for both conv & fc layers.

False
Source code in pytorch_optimizer/optimizer/gradient_centralization.py
 4
 5
 6
 7
 8
 9
10
11
12
def centralize_gradient(grad: torch.Tensor, gc_conv_only: bool = False) -> None:
    r"""Gradient Centralization (GC).

    Subtracts, in place, the mean taken over every dimension except the first. Tensors with too few dimensions
    are left untouched: more than 3 dimensions are required when `gc_conv_only` is set, more than 1 otherwise.

    :param grad: torch.Tensor. gradient.
    :param gc_conv_only: bool. 'False' for both conv & fc layers.
    """
    n_dims: int = grad.dim()

    # the tensor must have strictly more dimensions than this threshold to be centralized.
    threshold: int = 3 if gc_conv_only else 1
    if n_dims <= threshold:
        return

    reduce_dims = tuple(range(1, n_dims))
    grad.add_(-grad.mean(dim=reduce_dims, keepdim=True))

Grams

Bases: BaseOptimizer

Gradient Descent with Adaptive Momentum Scaling.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. decoupled weight decay.

True
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/grams.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class Grams(BaseOptimizer):
    r"""Gradient Descent with Adaptive Momentum Scaling.

    The magnitude of each update is the absolute value of the bias-corrected Adam step, while its direction
    is the sign of the current gradient.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and the squared gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. decoupled weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Grams'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the moment buffers for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                # exp_avg = beta1 * exp_avg + (1 - beta1) * grad. The new gradient gets weight (1 - beta1),
                # in line with the `bias_correction1` divisor applied below.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # `bias_correction2_sq` is already a square root, so it divides sqrt(exp_avg_sq) directly.
                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                # magnitude from the Adam step, direction from the sign of the current gradient.
                update = (exp_avg / bias_correction1) / de_nom
                update.abs_().mul_(grad.sign())

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                p.add_(update, alpha=-group['lr'])

        return loss

Gravity

Bases: BaseOptimizer

a Kinematic Approach on Optimization in Deep Learning.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
alpha float

float. alpha controls the V initialization.

0.01
beta float

float. beta will be used to compute running average of V.

0.9
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/gravity.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class Gravity(BaseOptimizer):
    r"""a Kinematic Approach on Optimization in Deep Learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param alpha: float. alpha controls the V initialization.
    :param beta: float. beta will be used to compute running average of V.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        alpha: float = 0.01,
        beta: float = 0.9,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        self.maximize = maximize

        defaults: DEFAULTS = dict(lr=lr, alpha=alpha, beta=beta)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Gravity'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the velocity buffer ``v`` for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                continue

            # `v` starts as zero-mean Gaussian noise with standard deviation alpha / lr.
            velocity = torch.empty_like(param)
            param_state['v'] = velocity.normal_(mean=0.0, std=group['alpha'] / group['lr'])

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            # averaging coefficient for this step, derived from `beta` and the step count.
            beta_t: float = (group['beta'] * group['step'] + 1) / (group['step'] + 2)
            lr = group['lr']

            for param in group['params']:
                if param.grad is None:
                    continue

                grad = param.grad
                self.maximize_gradient(grad, maximize=self.maximize)

                param_state = self.state[param]

                weight, grad, velocity = self.view_as_real(param, grad, param_state['v'])

                # damp the gradient: entries are divided by 1 + (g * max|g|)^2.
                inv_max = 1.0 / grad.abs().max()
                damping = 1.0 + (grad / inv_max) ** 2
                zeta = grad / damping

                velocity.mul_(beta_t)
                velocity.add_(zeta, alpha=1.0 - beta_t)

                weight.add_(velocity, alpha=-lr)

        return loss

GrokFastAdamW

Bases: BaseOptimizer

Accelerated Grokking by Amplifying Slow Gradients with AdamW.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.99)
grokfast bool

bool. whether to use grokfast.

True
grokfast_alpha float

float. momentum hyperparameter of the EMA.

0.98
grokfast_lamb float

float. amplifying factor hyperparameter of the filter.

2.0
grokfast_after_step int

int. warmup step for grokfast.

0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/grokfast.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
class GrokFastAdamW(BaseOptimizer):
    r"""Accelerated Grokking by Amplifying Slow Gradients with AdamW.

        When grokfast is active, an exponential moving average of past gradients is scaled by `grokfast_lamb` and
        added to the current gradient before the Adam-style first/second moment updates are applied.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the squared gradient.
    :param grokfast: bool. whether to use grokfast.
    :param grokfast_alpha: float. momentum hyperparameter of the EMA.
    :param grokfast_lamb: float. amplifying factor hyperparameter of the filter.
    :param grokfast_after_step: int. warmup step for grokfast.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param normalize_lr: bool. when grokfast is enabled, divide the learning rate by (1 + grokfast_lamb).
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.9, 0.99),
        grokfast: bool = True,
        grokfast_alpha: float = 0.98,
        grokfast_lamb: float = 2.0,
        grokfast_after_step: int = 0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        normalize_lr: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(grokfast_alpha, 'grokfast_alpha', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # The filtered gradient is `grad + grokfast_lamb * ema`, so the learning rate is scaled down by
        # (1 + grokfast_lamb) to compensate. The stored group 'lr' is therefore the normalized value.
        if grokfast and normalize_lr:
            lr /= 1.0 + grokfast_lamb

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'grokfast': grokfast,
            'grokfast_alpha': grokfast_alpha,
            'grokfast_lamb': grokfast_lamb,
            'grokfast_after_step': grokfast_after_step,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'GrokFastAdamW'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter in the group that currently has a gradient.

        :param group: GROUP. parameter group whose state is initialized.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['grokfast'] and group['grokfast_lamb'] > 0.0:
                    # The gradient EMA is seeded with a copy of the first observed gradient.
                    # NOTE(review): this runs before `maximize_gradient` is applied in `step`, so the seed is the
                    # gradient exactly as stored on `p.grad` here - confirm this is intended when maximize=True.
                    state['grok_exp_avg'] = grad.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # State is created only the first time a group is stepped; afterwards just the counter advances.
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # The filter is applied only when enabled, past the warmup step, and with a positive amplifying factor.
            should_grokfast: bool = (
                group['grokfast'] and group['step'] > group['grokfast_after_step'] and group['grokfast_lamb'] > 0.0
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # 'grok_exp_avg' exists only when it was created in `init_group`, hence the `.get`.
                exp_avg, exp_avg_sq, grok_exp_avg = (
                    state['exp_avg'],
                    state['exp_avg_sq'],
                    state.get('grok_exp_avg', None),
                )

                p, grad, exp_avg, exp_avg_sq, grok_exp_avg = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, grok_exp_avg
                )

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if should_grokfast:
                    # Move the EMA towards the current gradient, then add the scaled EMA to the gradient in place.
                    grok_exp_avg.lerp_(grad, weight=1.0 - group['grokfast_alpha'])
                    grad.add_(grok_exp_avg, alpha=group['grokfast_lamb'])

                # First and second moment estimates of the (possibly filtered) gradient.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # `eps` acts as a lower bound on the denominator (clamp) rather than being added to it.
                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).clamp_(min=group['eps'])

                update = exp_avg.div(bias_correction1).div_(de_nom)

                p.add_(update, alpha=-group['lr'])

        return loss

GSAM

Bases: BaseOptimizer

Surrogate Gap Guided Sharpness-Aware Minimization.

Example:

Here's an example::

    model = YourModel()
    base_optimizer = AdamP(model.parameters())
    lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps)
    rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr)
    optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler)

    def loss_fn(predictions, targets):
        return F.cross_entropy(predictions, targets)

    for inputs, targets in data:
        optimizer.set_closure(loss_fn, inputs, targets)
        predictions, loss = optimizer.step()
        lr_scheduler.step()
        optimizer.update_rho_t()

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer Optimizer

Optimizer. base optimizer.

required
model Module

nn.Module. model.

required
alpha float

float. rho alpha.

0.4
rho_scheduler

rho scheduler.

required
adaptive bool

bool. element-wise Adaptive SAM.

False
perturb_eps float

float. epsilon for perturbation.

1e-12
kwargs

Dict. parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
class GSAM(BaseOptimizer):  # pragma: no cover
    r"""Surrogate Gap Guided Sharpness-Aware Minimization.

    Example:
    -------
        Here's an example::

            model = YourModel()
            base_optimizer = AdamP(model.parameters())
            lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps)
            rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr)
            optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler)

            def loss_fn(predictions, targets):
                return F.cross_entropy(predictions, targets)

            for inputs, targets in data:
                optimizer.set_closure(loss_fn, inputs, targets)
                predictions, loss = optimizer.step()
                lr_scheduler.step()
                optimizer.update_rho_t()

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: Optimizer. base optimizer.
    :param model: nn.Module. model.
    :param rho_scheduler: rho scheduler. its `step()` return value is used as the perturbation radius.
    :param alpha: float. rho alpha. weight of the gradient component subtracted in `gradient_decompose`.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param perturb_eps: float. epsilon for perturbation.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        params: PARAMETERS,
        base_optimizer: Optimizer,
        model: nn.Module,
        rho_scheduler,
        alpha: float = 0.4,
        adaptive: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_range(alpha, 'alpha', 0.0, 1.0)

        self.model = model
        self.rho_scheduler = rho_scheduler
        self.alpha = alpha
        self.adaptive = adaptive
        self.perturb_eps = perturb_eps

        self.rho_t: float = 0.0
        self.forward_backward_func: Optional[Callable] = None

        # Use the native averaging reduction when available; otherwise sum and divide by the world size
        # manually in `sync_grad`.
        if hasattr(ReduceOp, 'AVG'):
            self.grad_reduce = ReduceOp.AVG
            self.manual_average: bool = False
        else:
            self.grad_reduce = ReduceOp.SUM
            self.manual_average: bool = True

        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups

        defaults: DEFAULTS = {'adaptive': adaptive, **kwargs}

        super().__init__(params, defaults)

        self.update_rho_t()

    def __str__(self) -> str:
        return 'GSAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        pass

    @torch.no_grad()
    def update_rho_t(self) -> float:
        r"""Advance the rho scheduler by one step and cache its value as the current perturbation radius."""
        self.rho_t = self.rho_scheduler.step()
        return self.rho_t

    @torch.no_grad()
    def perturb_weights(self, rho: float):
        r"""Move every parameter that has a gradient along its (optionally weight-scaled) gradient direction.

            The current gradient is saved in the state as 'old_g' and the applied offset as 'e_w' so that
            `gradient_decompose` and `un_perturb` can use them afterwards.

        :param rho: float. perturbation radius.
        """
        grad_norm = self.grad_norm(weight_adaptive=self.adaptive)

        # The scale does not depend on the group, so it is computed once.
        scale = rho / (grad_norm + self.perturb_eps)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['old_g'] = p.grad.clone()

                e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)

                p.add_(e_w)

                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def un_perturb(self):
        r"""Undo the offsets applied by `perturb_weights`.

        Each stored offset is removed from the state once it has been subtracted. Otherwise a parameter that was
        perturbed in an earlier step but has no gradient in the current one (and therefore was not perturbed
        again) would have a stale offset subtracted from it.
        """
        for group in self.param_groups:
            for p in group['params']:
                e_w = self.state[p].pop('e_w', None)
                if e_w is not None:
                    p.sub_(e_w)

    @torch.no_grad()
    def gradient_decompose(self, alpha: float = 0.0):
        r"""Subtract from the current gradient the part of the saved gradient that is orthogonal to it.

        :param alpha: float. weight of the orthogonal component that is subtracted.
        """
        # Inner product between the saved ('old_g') and current gradients, accumulated over all parameters.
        inner_prod = 0.0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                inner_prod += torch.sum(self.state[p]['old_g'] * p.grad)

        new_grad_norm = self.grad_norm(by=None)
        old_grad_norm = self.grad_norm(by='old_g')

        cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # old_g minus its projection onto the direction of the current gradient.
                vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad / (
                    new_grad_norm + self.perturb_eps
                )
                p.grad.add_(vertical, alpha=-alpha)

    @torch.no_grad()
    def sync_grad(self):
        r"""All-reduce the gradients across processes when the distributed backend is initialized."""
        if is_initialized():
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    all_reduce(p.grad, op=self.grad_reduce)
                    if self.manual_average:
                        p.grad.div_(float(get_world_size()))

    @torch.no_grad()
    def grad_norm(self, by: Optional[str] = None, weight_adaptive: bool = False) -> torch.Tensor:
        r"""Return the 2-norm of the stacked per-parameter 2-norms.

        :param by: Optional[str]. state key to read the tensor from; the current `p.grad` is used when falsy.
        :param weight_adaptive: bool. scale each tensor element-wise by `abs(p)` before taking its norm.
        """
        return torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if weight_adaptive else 1.0) * (p.grad if not by else self.state[p][by])).norm(p=2)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )

    def maybe_no_sync(self):
        r"""Return `model.no_sync()` when the distributed backend is initialized, otherwise an empty context."""
        return self.model.no_sync() if is_initialized() else ExitStack()

    @torch.no_grad()
    def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs):
        r"""Set closure.

            Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()`
            automatically performs forward and backward passes. This function does not take any arguments,
            and the inputs and targets data should be pre-set in the definition of partial-function.

        :param loss_fn: nn.Module. loss function.
        :param inputs: torch.Tensor. inputs.
        :param targets: torch.Tensor. targets.
        :param kwargs: extra keyword arguments forwarded to `loss_fn`.
        """

        def get_grad():
            self.base_optimizer.zero_grad()
            with torch.enable_grad():
                outputs = self.model(inputs)
                loss = loss_fn(outputs, targets, **kwargs)

            loss.backward()

            return outputs, loss.detach()

        self.forward_backward_func = get_grad

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> Tuple[torch.Tensor, float]:
        r"""Perform one GSAM step.

        :param closure: CLOSURE. optional callable performing forward and backward passes and returning
            `(outputs, loss)`; when omitted, the callable registered through `set_closure` is used.
        :return: the outputs and loss of the first (un-perturbed) forward pass.
        """
        get_grad = closure if closure else self.forward_backward_func

        with self.maybe_no_sync():
            # First pass: gradient at the current weights.
            outputs, loss = get_grad()

            self.perturb_weights(rho=self.rho_t)

            disable_running_stats(self.model)

            # Second pass: gradient at the perturbed weights.
            get_grad()

            self.gradient_decompose(self.alpha)

            self.un_perturb()

        self.sync_grad()

        self.base_optimizer.step()

        enable_running_stats(self.model)

        return outputs, loss

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        # Keep the base optimizer pointing at the param groups restored on this wrapper.
        self.base_optimizer.param_groups = self.param_groups

set_closure(loss_fn, inputs, targets, **kwargs)

Set closure.

Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()`
automatically performs forward and backward passes. This function does not take any arguments,
and the inputs and targets data should be pre-set in the definition of partial-function.

Parameters:

Name Type Description Default
loss_fn Module

nn.Module. loss function.

required
inputs Tensor

torch.Tensor. inputs.

required
targets Tensor

torch.Tensor. targets.

required
Source code in pytorch_optimizer/optimizer/sam.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
@torch.no_grad()
def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs):
    r"""Register the forward/backward routine.

        Stores a zero-argument callable in `self.forward_backward_func`. Calling it clears the gradients of the
        base optimizer, runs the model on `inputs`, evaluates `loss_fn` against `targets` (with any extra
        keyword arguments), back-propagates, and returns the outputs together with the detached loss.

    :param loss_fn: nn.Module. loss function.
    :param inputs: torch.Tensor. inputs.
    :param targets: torch.Tensor. targets.
    """

    def forward_backward():
        self.base_optimizer.zero_grad()

        # Gradient tracking is re-enabled for the forward pass only.
        with torch.enable_grad():
            predictions = self.model(inputs)
            loss_value = loss_fn(predictions, targets, **kwargs)

        loss_value.backward()

        detached_loss = loss_value.detach()
        return predictions, detached_loss

    self.forward_backward_func = forward_backward

Kate

Bases: BaseOptimizer

Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
delta float

float. delta. 0.0 or 1e-8.

0.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. epsilon value.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/kate.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class Kate(BaseOptimizer):
    r"""Remove that Square Root: A New Efficient Scale-Invariant Version of AdaGrad.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param delta: float. delta. 0.0 or 1e-8.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. epsilon value.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        delta: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(delta, 'delta', 0.0, 1.0, '[)')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'delta': delta,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Kate'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the zero-initialized 'm' and 'b' state for every parameter in the group that has a gradient.

        :param group: GROUP. parameter group whose state is initialized.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['m'] = torch.zeros_like(p)
                state['b'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                m, b = state['m'], state['b']

                p, grad, m, b = self.view_as_real(p, grad, m, b)

                # Pass the local `grad` rather than `p.grad`: `p` has just been rebound by `view_as_real`, so
                # `p.grad` is not guaranteed to be the gradient tensor processed here. When `p` is not rebound
                # the two are the same tensor.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                grad_p2 = grad.pow(2)

                # `b` is kept as a square root between steps (see `b.sqrt_()` below): square it back, then
                # accumulate the squared gradient and `eps`.
                b.mul_(b).add_(grad_p2).add_(group['eps'])
                # `m` is likewise stored as a square root: square it, accumulate, and take the root again.
                m.mul_(m).add_(grad_p2, alpha=group['delta']).add_(grad_p2 / b).sqrt_()

                update = m.mul(grad).div_(b)

                p.add_(update, alpha=-group['lr'])

                b.sqrt_()

        return loss

Lamb

Bases: BaseOptimizer

Large Batch Optimization for Deep Learning.

This Lamb implementation is based on the paper v3, which does not use de-biasing.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and of the squared gradient.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
rectify bool

bool. perform the rectified update similar to RAdam.

False
degenerated_to_sgd bool

bool. degenerated to SGD.

False
n_sma_threshold int

int. (recommended is 5).

5
grad_averaging bool

bool. whether to apply (1 - beta1) to the gradient when calculating the running average of the gradient.

True
max_grad_norm float

float. max gradient norm to clip.

1.0
adam bool

bool. always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes.

False
pre_norm bool

bool. perform pre-normalization of all gradients.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lamb.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
class Lamb(BaseOptimizer):
    r"""Large Batch Optimization for Deep Learning.

        This Lamb implementation is based on the paper v3, which does not use de-biasing.

        The optional keys `adanorm`, `adanorm_r` and `adam_debias` are read from the parameter group when
        present; they can be supplied through `**kwargs`, which is merged into the defaults.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the squared
        gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param degenerated_to_sgd: bool. degenerated to SGD.
    :param n_sma_threshold: int. (recommended is 5).
    :param grad_averaging: bool. whether to apply (1 - beta1) to the gradient when calculating the running
        average of the gradient.
    :param max_grad_norm: float. max gradient norm to clip. 0.0 disables the clipping done by `pre_norm`.
    :param adam: bool. always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes.
    :param pre_norm: bool. perform pre-normalization of all gradients.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    clamp: float = 10.0

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        degenerated_to_sgd: bool = False,
        n_sma_threshold: int = 5,
        grad_averaging: bool = True,
        max_grad_norm: float = 1.0,
        adam: bool = False,
        pre_norm: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(max_grad_norm, 'max_grad_norm')
        self.validate_non_negative(eps, 'eps')

        self.degenerated_to_sgd = degenerated_to_sgd
        self.n_sma_threshold = n_sma_threshold
        self.pre_norm = pre_norm
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'grad_averaging': grad_averaging,
            'max_grad_norm': max_grad_norm,
            'adam': adam,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Lamb'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter in the group that currently has a gradient.

        :param group: GROUP. parameter group whose state is initialized.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
        r"""Return the global gradient clipping coefficient.

        Despite the name, the returned value is `max_grad_norm / (global_norm + eps)` clamped to at most 1.0
        (or exactly 1.0 when `max_grad_norm` is 0.0). It is therefore a multiplicative factor in (0, 1].
        """
        if self.defaults['max_grad_norm'] == 0.0:
            return 1.0

        global_grad_norm = get_global_gradient_norm(self.param_groups)
        global_grad_norm.sqrt_().add_(self.defaults['eps'])

        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        grad_norm = 1.0
        if self.pre_norm:
            grad_norm = self.get_global_gradient_norm()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            beta3: float = 1.0 - beta1 if group['grad_averaging'] else 1.0
            bias_correction1: float = self.debias(beta1, group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                if self.pre_norm:
                    # `grad_norm` is a clipping coefficient <= 1.0, so it must scale the gradient down by
                    # multiplication; dividing by it would enlarge gradients whose global norm exceeds
                    # `max_grad_norm`.
                    grad.mul_(grad_norm)

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=beta3)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['rectify']:
                    # In the rectified path `update` holds the would-be new weights, not the raw step.
                    update = p.clone()
                    if n_sma >= self.n_sma_threshold:
                        de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                        update.addcdiv_(exp_avg, de_nom, value=-step_size)
                    else:
                        update.add_(exp_avg, alpha=-step_size)
                else:
                    update = exp_avg / exp_avg_sq.sqrt().add_(group['eps'])

                # The weight norm is capped at `self.clamp` before forming the trust ratio.
                weight_norm = torch.linalg.norm(p).clamp_(min=0, max=self.clamp)
                p_norm = torch.linalg.norm(update)
                trust_ratio: float = 1.0 if weight_norm == 0 or p_norm == 0 else weight_norm / (p_norm + group['eps'])

                # Stored for inspection; the recorded trust ratio is the value before the `adam` override.
                state['weight_norm'] = weight_norm
                state['adam_norm'] = p_norm
                state['trust_ratio'] = trust_ratio

                if group['adam']:
                    trust_ratio = 1.0

                if group['rectify']:
                    # `de_nom` was defined above under the same `n_sma` condition.
                    if n_sma >= self.n_sma_threshold:
                        p.addcdiv_(exp_avg, de_nom, value=-step_size * trust_ratio)
                    else:
                        p.add_(exp_avg, alpha=-step_size * trust_ratio)
                else:
                    p.add_(update, alpha=-step_size * trust_ratio)

        return loss

LaProp

Bases: BaseOptimizer

Separating Momentum and Adaptivity in Adam.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0004
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
centered bool

bool.

False
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. epsilon value.

1e-15
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/laprop.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
class LaProp(BaseOptimizer):
    r"""Separating Momentum and Adaptivity in Adam.

    Each gradient is first divided by the square root of its second-moment estimate, and the momentum buffer then
    accumulates these already-normalized gradients.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the normalized gradient (beta1) and of
        the squared gradient (beta2).
    :param centered: bool. normalize with the centered second moment, i.e. subtract the squared running mean of the
        gradient from the running mean of the squared gradient.
    :param steps_before_using_centered: int. number of steps during which the centered estimate is only accumulated
        and not yet used for normalization.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. epsilon value.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 4e-4,
        betas: BETAS = (0.9, 0.999),
        centered: bool = False,
        steps_before_using_centered: int = 10,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        eps: float = 1e-15,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.steps_before_using_centered: int = steps_before_using_centered
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'centered': centered,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'LaProp'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_lr_1'] = 0.0
                state['exp_avg_lr_2'] = 0.0

                # optional buffers are only allocated when the matching option is enabled
                if group['centered']:
                    state['exp_mean_avg_beta2'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, exp_mean_avg_beta2 = (
                    state['exp_avg'],
                    state['exp_avg_sq'],
                    state.get('exp_mean_avg_beta2', None),
                )

                p, grad, exp_avg, exp_avg_sq, exp_mean_avg_beta2 = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, exp_mean_avg_beta2
                )

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                # running averages of the learning rate; the first one de-biases the momentum below
                state['exp_avg_lr_1'] = state['exp_avg_lr_1'] * beta1 + (1.0 - beta1) * group['lr']
                state['exp_avg_lr_2'] = state['exp_avg_lr_2'] * beta2 + (1.0 - beta2)

                bias_correction1: float = state['exp_avg_lr_1'] / group['lr'] if group['lr'] != 0.0 else 1.0
                step_size: float = 1.0 / bias_correction1

                second_moment = exp_avg_sq
                if group['centered']:
                    exp_mean_avg_beta2.mul_(beta2).add_(grad, alpha=1.0 - beta2)
                    if group['step'] > self.steps_before_using_centered:
                        # Out-of-place on purpose: an in-place subtraction would write into the persistent
                        # `exp_avg_sq` state and corrupt the running average on every step. The clamp only
                        # guards against round-off pushing the (mathematically non-negative) variance below 0.
                        second_moment = (exp_avg_sq - exp_mean_avg_beta2.pow(2)).clamp_(min=0.0)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=second_moment,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                # momentum over the normalized gradient, already scaled by the learning rate
                exp_avg.mul_(beta1).addcdiv_(grad, de_nom, value=(1.0 - beta1) * group['lr'])

                if group.get('cautious'):
                    # work on a copy so the stored momentum is not altered by the cautious mask
                    update = exp_avg.clone()
                    self.apply_cautious(update, grad)
                else:
                    update = exp_avg

                p.add_(update, alpha=-step_size)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

        return loss

LARS

Bases: BaseOptimizer

Layer-wise Adaptive Rate Scaling (no rate scaling or weight decay for parameters <= 1D).

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
weight_decay float

float. weight decay (L2 penalty).

0.0
momentum float

float. momentum.

0.9
dampening float

float. dampening for momentum.

0.0
trust_coefficient float

float. trust_coefficient.

0.001
nesterov bool

bool. enables nesterov momentum.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lars.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
class LARS(BaseOptimizer):
    r"""Layer-wise Adaptive Rate Scaling (no rate scaling or weight decay for parameters <= 1D).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param momentum: float. momentum.
    :param dampening: float. dampening for momentum.
    :param trust_coefficient: float. coefficient scaling the ratio of parameter norm to gradient norm.
    :param nesterov: bool. enables nesterov momentum.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.9,
        dampening: float = 0.0,
        trust_coefficient: float = 1e-3,
        nesterov: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(dampening, 'dampening', 0.0, 1.0)
        self.validate_non_negative(trust_coefficient, 'trust_coefficient')

        self.maximize = maximize

        super().__init__(
            params,
            {
                'lr': lr,
                'weight_decay': weight_decay,
                'momentum': momentum,
                'dampening': dampening,
                'trust_coefficient': trust_coefficient,
                'nesterov': nesterov,
            },
        )

    def __str__(self) -> str:
        return 'Lars'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Seed the momentum buffer of every parameter in ``group`` that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            if not group['momentum'] > 0.0:
                continue

            param_state = self.state[param]
            if 'momentum_buffer' not in param_state:
                # the buffer starts out as a copy of the current gradient
                param_state['momentum_buffer'] = gradient.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]

                # weight decay and trust-ratio scaling are applied to tensors with more than one dimension only
                if param.ndim > 1:
                    weight_norm = torch.linalg.norm(param)
                    gradient_norm = torch.linalg.norm(gradient)

                    unit = torch.ones_like(weight_norm)
                    scaled = group['trust_coefficient'] * weight_norm / gradient_norm

                    # fall back to a ratio of 1 whenever either norm is zero
                    ratio = torch.where(
                        weight_norm > 0.0,
                        torch.where(gradient_norm > 0.0, scaled, unit),
                        unit,
                    )

                    gradient.add_(param, alpha=group['weight_decay'])
                    gradient.mul_(ratio)

                momentum = group['momentum']
                if momentum > 0.0:
                    buffer = param_state['momentum_buffer']
                    buffer.mul_(momentum).add_(gradient, alpha=1.0 - group['dampening'])

                    if group['nesterov']:
                        gradient.add_(buffer, alpha=momentum)
                    else:
                        gradient.copy_(buffer)

                param.add_(gradient, alpha=-group['lr'])

        return loss

Lion

Bases: BaseOptimizer

Symbolic Discovery of Optimization Algorithms.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0001
betas BETAS

BETAS. coefficients used for interpolating the update direction (beta1) and for the running average of the gradient (beta2).

(0.9, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lion.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class Lion(BaseOptimizer):
    r"""Symbolic Discovery of Optimization Algorithms.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for interpolating the update direction (beta1) and for the running
        average of the gradient (beta2).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        super().__init__(
            params,
            {
                'lr': lr,
                'betas': betas,
                'weight_decay': weight_decay,
                'weight_decouple': weight_decouple,
                'fixed_decay': fixed_decay,
                **kwargs,
            },
        )

    def __str__(self) -> str:
        return 'Lion'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if param_state:
                # state already populated for this parameter
                continue

            param_state['exp_avg'] = torch.zeros_like(param)

            if group.get('adanorm'):
                param_state['exp_grad_adanorm'] = torch.zeros((1,), dtype=gradient.dtype, device=gradient.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]
                momentum = param_state['exp_avg']

                param, gradient, momentum = self.view_as_real(param, gradient, momentum)

                if group.get('use_gc'):
                    centralize_gradient(gradient, gc_conv_only=False)

                self.apply_weight_decay(
                    p=param,
                    grad=gradient,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                adanorm_gradient = self.get_adanorm_gradient(
                    grad=gradient,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=param_state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                # the direction is the sign of the beta1-interpolation, computed on a copy of the momentum
                direction = momentum.clone().mul_(beta1).add_(gradient, alpha=1.0 - beta1).sign_()

                # the stored momentum itself is updated with beta2
                momentum.mul_(beta2).add_(adanorm_gradient, alpha=1.0 - beta2)

                if group.get('cautious'):
                    self.apply_cautious(direction, gradient)

                param.add_(direction, alpha=-group['lr'])

        return loss

LOMO

Bases: BaseOptimizer

Full Parameter Fine-tuning for Large Language Models with Limited Resources.

Reference : https://github.com/OpenLMLab/LOMO/blob/main/src/lomo.py Check the usage from here : https://github.com/OpenLMLab/LOMO/blob/main/lomo/src/lomo_trainer.py

Parameters:

Name Type Description Default
model Module

nn.Module. pytorch model.

required
lr float

float. learning rate.

0.001
clip_grad_norm Optional[float]

Optional[float]. clip grad norm.

None
clip_grad_value Optional[float]

Optional[float]. clip grad value.

None
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/lomo.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
class LOMO(BaseOptimizer):
    r"""Full Parameter Fine-tuning for Large Language Models with Limited Resources.

    The parameter update is fused into the backward pass: a hook registered on every trainable parameter applies
    the update and drops the gradient right away. Use :meth:`fused_backward` (and :meth:`grad_norm` when gradient
    norm clipping is enabled) instead of ``loss.backward()`` followed by ``step()``.

    Reference : https://github.com/OpenLMLab/LOMO/blob/main/src/lomo.py
    Check the usage from here : https://github.com/OpenLMLab/LOMO/blob/main/lomo/src/lomo_trainer.py

    :param model: nn.Module. pytorch model.
    :param lr: float. learning rate.
    :param clip_grad_norm: Optional[float]. clip grad norm.
    :param clip_grad_value: Optional[float]. clip grad value.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-3,
        clip_grad_norm: Optional[float] = None,
        clip_grad_value: Optional[float] = None,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(clip_grad_norm, 'clip_grad_norm')
        self.validate_non_negative(clip_grad_value, 'clip_grad_value')

        self.model = model
        self.lr = lr
        self.clip_grad_norm = clip_grad_norm
        self.clip_grad_value = clip_grad_value
        self.maximize = maximize

        self.local_rank: int = int(os.environ.get('LOCAL_RANK', '0'))

        # `gather_norm` switches the hook between "collect gradient norms" and "apply the update"
        self.gather_norm: bool = False
        self.grad_norms: List[torch.Tensor] = []
        self.clip_coef: Optional[float] = None

        p0: torch.Tensor = next(iter(self.model.parameters()))

        # parameters carrying a `ds_tensor` attribute get the partitioned (zero3) update function
        self.grad_func: Callable[[Any], Any] = (
            self.fuse_update_zero3() if hasattr(p0, 'ds_tensor') else self.fuse_update()
        )

        # a dynamic loss scaler is only created for float16 models
        self.loss_scaler: Optional[DynamicLossScaler] = None
        if p0.dtype == torch.float16:
            if clip_grad_norm is None:
                raise ValueError('loss scaling is recommended to be used with grad norm to get better performance.')

            self.loss_scaler = DynamicLossScaler(init_scale=2 ** 16)  # fmt: skip

        for _, p in self.model.named_parameters():
            if p.requires_grad:
                p.register_hook(self.grad_func)

        defaults: DEFAULTS = {'lr': lr}

        super().__init__(self.model.parameters(), defaults)

    def __str__(self) -> str:
        return 'LOMO'

    def init_group(self, group: GROUP, **kwargs) -> None:
        pass

    def fuse_update(self) -> Callable[[Any], Any]:
        r"""Build the gradient hook that applies the fused update.

        :return: a callable that consumes every available ``p.grad`` of the model and returns its argument unchanged.
        """

        @torch.no_grad()
        def func(x: Any) -> Any:
            for _, p in self.model.named_parameters():
                if not p.requires_grad or p.grad is None:
                    continue

                if (self.loss_scaler and self.loss_scaler.has_overflow_serial) or has_overflow(p.grad):
                    # drop the gradient and stop; the overflow flag only exists when a loss scaler is in use, so
                    # guard the assignment instead of dereferencing None for non-fp16 models.
                    p.grad = None
                    if self.loss_scaler:
                        self.loss_scaler.has_overflow_serial = True
                    break

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                if self.loss_scaler:
                    grad_fp32.div_(self.loss_scaler.loss_scale)

                if self.gather_norm:
                    # norm-gathering pass: record the norm, leave the parameter untouched
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                else:
                    if self.clip_grad_value is not None and self.clip_grad_value > 0.0:
                        grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)
                    if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is not None:
                        grad_fp32.mul_(self.clip_coef)

                    # the update is computed in float32 and copied back into the parameter
                    p_fp32 = p.to(torch.float32)
                    p_fp32.add_(grad_fp32, alpha=-self.lr)
                    p.copy_(p_fp32)

            return x

        return func

    def fuse_update_zero3(self) -> Callable[[Any], Any]:  # pragma: no cover
        r"""Build the gradient hook for parameters that expose a ``ds_tensor`` partition.

        :return: a callable that consumes every available ``p.grad`` of the model and returns its argument unchanged.
        """

        @torch.no_grad()
        def func(x: torch.Tensor) -> torch.Tensor:
            for _, p in self.model.named_parameters():
                if p.grad is None:
                    continue

                all_reduce(p.grad, op=ReduceOp.AVG, async_op=False)

                if (self.loss_scaler and self.loss_scaler.has_overflow_serial) or has_overflow(p.grad):
                    # same guard as in `fuse_update`: no loss scaler means there is no flag to set
                    p.grad = None
                    if self.loss_scaler:
                        self.loss_scaler.has_overflow_serial = True
                    break

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                param_fp32 = p.ds_tensor.to(torch.float32)
                if self.loss_scaler:
                    grad_fp32.div_(self.loss_scaler.loss_scale)

                if self.gather_norm:
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                else:
                    one_dim_grad_fp32 = grad_fp32.view(-1)

                    # slice of the flattened gradient selected by this process' LOCAL_RANK
                    partition_size: int = p.ds_tensor.numel()
                    start: int = partition_size * self.local_rank
                    end: int = min(start + partition_size, grad_fp32.numel())

                    partitioned_grad_fp32 = one_dim_grad_fp32.narrow(0, start, end - start)

                    if self.clip_grad_value is not None:
                        partitioned_grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)

                    if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is not None:
                        partitioned_grad_fp32.mul_(self.clip_coef)

                    partitioned_p = param_fp32.narrow(0, 0, end - start)
                    partitioned_p.add_(partitioned_grad_fp32, alpha=-self.lr)

                    p.ds_tensor[: end - start] = partitioned_p  # fmt: skip

            return x

        return func

    def fused_backward(self, loss, lr: float):
        r"""Run the backward pass with the parameter update fused into it.

        :param loss: loss tensor to back-propagate.
        :param lr: float. learning rate used for this update.
        :raises ValueError: if ``clip_grad_norm`` is enabled but :meth:`grad_norm` has not produced ``clip_coef``.
        """
        self.lr = lr

        if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is None:
            raise ValueError(
                'clip_grad_norm is not None, but clip_coef is None. '
                'Please call optimizer.grad_norm() before optimizer.fused_backward().'
            )

        if self.loss_scaler:
            loss = loss * self.loss_scaler.loss_scale

        loss.backward()

        # one extra pass picks up any gradient still attached after backward; the argument is only a placeholder
        self.grad_func(0)

    def grad_norm(self, loss):
        r"""Run a backward pass that only gathers gradient norms and derives ``clip_coef`` from them.

        When the loss scaler reports an overflow, the scale is updated, all gradients are cleared and the method
        returns early without computing ``clip_coef`` (``gather_norm`` is left enabled in that case).

        :param loss: loss tensor to back-propagate; the graph is retained for the following backward pass.
        """
        self.gather_norm = True
        self.grad_norms = []

        if self.loss_scaler:
            self.loss_scaler.has_overflow_serial = False
            loss = loss * self.loss_scaler.loss_scale

        loss.backward(retain_graph=True)

        self.grad_func(0)

        if self.loss_scaler and self.loss_scaler.has_overflow_serial:
            self.loss_scaler.update_scale(overflow=True)

            with torch.no_grad():
                for _, p in self.model.named_parameters():
                    p.grad = None
            return

        with torch.no_grad():
            self.grad_norms = torch.stack(self.grad_norms)

            # global norm over all per-parameter norms; the coefficient never scales gradients up
            total_norm = torch.norm(self.grad_norms, 2.0)
            self.clip_coef = torch.clamp(float(self.clip_grad_norm) / (total_norm + 1e-6), max=1.0)

        self.gather_norm = False

Lookahead

Bases: BaseOptimizer

k steps forward, 1 step back.

Parameters:

Name Type Description Default
optimizer OPTIMIZER_INSTANCE_OR_CLASS

OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.

required
k int

int. number of lookahead steps.

5
alpha float

float. linear interpolation factor.

0.5
pullback_momentum str

str. change to inner optimizer momentum on interpolation update.

'none'
Source code in pytorch_optimizer/optimizer/lookahead.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class Lookahead(BaseOptimizer):
    r"""k steps forward, 1 step back.

    Wraps a base optimizer: every ``k`` calls to :meth:`step` the fast weights are interpolated with a slow copy
    of the weights, and the slow copy is set to the result.

    :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.
    :param k: int. number of lookahead steps.
    :param alpha: float. linear interpolation factor.
    :param pullback_momentum: str. change to inner optimizer momentum on interpolation update. one of 'none',
        'reset' or 'pullback'.
    :param kwargs: extra arguments handed to ``load_optimizer`` together with ``optimizer``.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER_INSTANCE_OR_CLASS,
        k: int = 5,
        alpha: float = 0.5,
        pullback_momentum: str = 'none',
        **kwargs,
    ) -> None:
        self.validate_positive(k, 'k')
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        self.alpha = alpha
        self.k = k
        self.pullback_momentum = pullback_momentum

        self.state: STATE = defaultdict(dict)

        for group in self.param_groups:
            if 'counter' not in group:
                group['counter'] = 0

            for p in group['params']:
                # the slow weights start as a copy of the current weights
                state = self.state[p]
                state['slow_params'] = torch.empty_like(p)
                state['slow_params'].copy_(p)
                if self.pullback_momentum == 'pullback':
                    state['slow_momentum'] = torch.zeros_like(p)

        self.defaults: DEFAULTS = {
            'lookahead_alpha': alpha,
            'lookahead_k': k,
            'lookahead_pullback_momentum': pullback_momentum,
            **self.optimizer.defaults,
        }

    @property
    def param_groups(self):
        r"""Parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'k': self.k,
            'pullback_momentum': self.pullback_momentum,
        }

    @torch.no_grad()
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Clear the gradients through the wrapped optimizer.

        :param set_to_none: bool. forwarded to the wrapped optimizer's ``zero_grad``.
        """
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def init_group(self, group: GROUP, **kwargs) -> None:
        pass

    def backup_and_load_cache(self):
        r"""Save the current weights as ``backup_params`` and load the slow weights into the parameters."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['backup_params'] = torch.empty_like(p)
                state['backup_params'].copy_(p)
                p.data.copy_(state['slow_params'])

    def clear_and_load_backup(self):
        r"""Restore the weights saved by :meth:`backup_and_load_cache` and discard the backup."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                p.data.copy_(state['backup_params'])
                del state['backup_params']

    def state_dict(self) -> STATE:
        r"""Return the Lookahead state together with the wrapped optimizer's state dict."""
        return {'lookahead_state': self.state, 'base_optimizer': self.optimizer.state_dict()}

    def load_state_dict(self, state: STATE) -> None:
        r"""Load state.

        :param state: STATE. dictionary with the keys 'lookahead_state' and 'base_optimizer'.
        """
        self.state = state['lookahead_state']
        self.optimizer.load_state_dict(state['base_optimizer'])

    @torch.no_grad()
    def update(self, group: Dict):
        r"""Interpolate fast and slow weights for every parameter of ``group`` that has a gradient.

        :param group: Dict. parameter group of the wrapped optimizer.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            state = self.state[p]

            slow = state['slow_params']

            # p <- alpha * p + (1 - alpha) * slow, then the slow weights follow
            p.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
            slow.copy_(p)

            if 'momentum_buffer' not in self.optimizer.state[p]:
                self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)

            if self.pullback_momentum == 'pullback':
                internal_momentum = self.optimizer.state[p]['momentum_buffer']
                self.optimizer.state[p]['momentum_buffer'] = internal_momentum.mul_(self.alpha).add_(
                    state['slow_momentum'], alpha=1.0 - self.alpha
                )
                # Keep an independent copy. Storing the buffer itself would alias the slow momentum to the
                # tensor the inner optimizer keeps updating in place, so the next interpolation would mix the
                # buffer with itself.
                state['slow_momentum'] = self.optimizer.state[p]['momentum_buffer'].clone()
            elif self.pullback_momentum == 'reset':
                self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)

    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Step the wrapped optimizer and run the lookahead update on every ``k``-th call.

        :param closure: CLOSURE. optional callable forwarded to the wrapped optimizer.
        :return: the value returned by the wrapped optimizer's ``step``.
        """
        loss: LOSS = self.optimizer.step(closure)
        for group in self.param_groups:
            group['counter'] += 1
            if group['counter'] >= self.k:
                group['counter'] = 0
                self.update(group)
        return loss

backup_and_load_cache()

Backup cache parameters.

Source code in pytorch_optimizer/optimizer/lookahead.py
81
82
83
84
85
86
87
88
def backup_and_load_cache(self):
    r"""Save the current weights as ``backup_params`` and load the slow weights into the parameters."""
    for param_group in self.param_groups:
        for param in param_group['params']:
            param_state = self.state[param]

            # snapshot of the current weights, kept until `clear_and_load_backup` is called
            snapshot = torch.empty_like(param)
            param_state['backup_params'] = snapshot
            snapshot.copy_(param)

            param.data.copy_(param_state['slow_params'])

clear_and_load_backup()

Load backup parameters.

Source code in pytorch_optimizer/optimizer/lookahead.py
90
91
92
93
94
95
96
def clear_and_load_backup(self):
    r"""Copy the saved ``backup_params`` back into the parameters and drop the backup from the state."""
    for param_group in self.param_groups:
        for param in param_group['params']:
            param_state = self.state[param]
            saved = param_state['backup_params']
            param.data.copy_(saved)
            del param_state['backup_params']

load_state_dict(state)

Load state.

Source code in pytorch_optimizer/optimizer/lookahead.py
101
102
103
104
def load_state_dict(self, state: STATE) -> None:
    r"""Restore the wrapper state and the wrapped optimizer from one combined mapping.

    :param state: STATE. mapping holding ``'lookahead_state'`` (assigned to ``self.state``) and
        ``'base_optimizer'`` (passed to the wrapped optimizer's ``load_state_dict``).
    """
    lookahead_part = state['lookahead_state']
    self.state = lookahead_part

    base_part = state['base_optimizer']
    self.optimizer.load_state_dict(base_part)

LookSAM

Bases: BaseOptimizer

Towards Efficient and Scalable Sharpness-Aware Minimization.

Example:

Here's an example::

    model = YourModel()
    base_optimizer = Ranger21
    optimizer = LookSAM(model.parameters(), base_optimizer)

    for input, output in data:
        # first forward-backward pass

        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass
        # make sure to do a full forward pass
        loss_function(output, model(input)).backward()
        optimizer.second_step(zero_grad=True)

Alternative example with a single closure-based step function::

    model = YourModel()
    base_optimizer = Ranger21
    optimizer = LookSAM(model.parameters(), base_optimizer)

    def closure():
        loss = loss_function(output, model(input))
        loss.backward()
        return loss

    for input, output in data:
        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.step(closure)
        optimizer.zero_grad()

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer OPTIMIZER

OPTIMIZER. base optimizer.

required
rho float

float. size of the neighborhood for computing the max loss.

0.1
k int

int. lookahead step.

10
alpha float

float. lookahead blending alpha.

0.7
adaptive bool

bool. element-wise Adaptive SAM.

False
use_gc bool

bool. perform gradient centralization, GCSAM variant.

False
perturb_eps float

float. eps for perturbation.

1e-12
kwargs

Dict. parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
class LookSAM(BaseOptimizer):
    r"""Towards Efficient and Scalable Sharpness-Aware Minimization.

    The parameters are only perturbed when the base optimizer's step counter is a multiple of ``k``. On the other
    steps the gradient component ``gv`` stored at the last perturbation step is blended into the current gradient.

    Example:
    -------
        Here's an example::

            model = YourModel()
            base_optimizer = Ranger21
            optimizer = LookSAM(model.parameters(), base_optimizer)

            for input, output in data:
                # first forward-backward pass

                loss = loss_function(output, model(input))
                loss.backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward pass
                # make sure to do a full forward pass
                loss_function(output, model(input)).backward()
                optimizer.second_step(zero_grad=True)

        Alternative example with a single closure-based step function::

            model = YourModel()
            base_optimizer = Ranger21
            optimizer = LookSAM(model.parameters(), base_optimizer)

            def closure():
                loss = loss_function(output, model(input))
                loss.backward()
                return loss

            for input, output in data:
                loss = loss_function(output, model(input))
                loss.backward()
                optimizer.step(closure)
                optimizer.zero_grad()

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: OPTIMIZER. base optimizer.
    :param rho: float. size of the neighborhood for computing the max loss.
    :param k: int. lookahead step.
    :param alpha: float. lookahead blending alpha.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param use_gc: bool. perform gradient centralization, GCSAM variant.
    :param perturb_eps: float. eps for perturbation.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        params: PARAMETERS,
        base_optimizer: OPTIMIZER,
        rho: float = 0.1,
        k: int = 10,
        alpha: float = 0.7,
        adaptive: bool = False,
        use_gc: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')
        self.validate_positive(k, 'k')
        self.validate_range(alpha, 'alpha', 0.0, 1.0, '()')
        self.validate_non_negative(perturb_eps, 'perturb_eps')

        self.k = k
        self.alpha = alpha
        self.use_gc = use_gc
        self.perturb_eps = perturb_eps

        defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
        defaults.update(kwargs)

        super().__init__(params, defaults)

        # the base optimizer is built on this optimizer's groups, and both then share one `param_groups` list.
        self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'LookSAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        pass

    def get_step(self):
        r"""Return the current step counter.

        Uses ``param_groups[0]['step']`` when present, otherwise the ``'step'`` entry of the first state held by the
        base optimizer, and ``0`` when the base optimizer has no state yet.
        """
        return (
            self.param_groups[0]['step']
            if 'step' in self.param_groups[0]
            else next(iter(self.base_optimizer.state.values()))['step'] if self.base_optimizer.state else 0
        )

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False) -> None:
        r"""Perturb the parameters along the current gradient when the step counter is a multiple of ``k``.

        On such steps the unperturbed parameters and the current gradient are saved in the per-parameter state
        (``'old_p'`` and ``'old_grad_p'``) for use by :meth:`second_step`.

        :param zero_grad: bool. clear the gradients before returning.
        """
        if self.get_step() % self.k != 0:
            # no perturbation on this step, but the request to clear the gradients still has to be honoured;
            # otherwise the next backward pass accumulates on top of the gradients that are already stored.
            if zero_grad:
                self.zero_grad()
            return

        device = self.param_groups[0]['params'][0].device

        grad_norm = get_global_gradient_norm(self.param_groups, device).add_(self.perturb_eps)

        for group in self.param_groups:
            scale = group['rho'] / grad_norm

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                # keyed by the parameter itself: a key built from the index inside the group would be shared by
                # parameters of different groups and they would overwrite each other's entry.
                state = self.state[p]
                state['old_p'] = p.clone()
                state['old_grad_p'] = grad.clone()

                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * grad * scale.to(p)

                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        r"""Restore the saved parameters, adjust the gradients and run the base optimizer.

        When the step counter is a multiple of ``k`` the component ``gv`` of the current gradient is computed from
        the gradient saved by :meth:`first_step` and stored. On the other steps the stored ``gv`` is rescaled and
        added to the gradient with weight ``alpha``.

        :param zero_grad: bool. clear the gradients after the base optimizer's step.
        """
        step = self.get_step()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                grad = p.grad
                grad_norm = grad.norm(p=2)

                if step % self.k == 0:
                    old_grad_p = state['old_grad_p']

                    g_grad_norm = old_grad_p / old_grad_p.norm(p=2)
                    g_s_grad_norm = grad / grad_norm

                    # remove from `grad` its projection onto the direction of the saved gradient.
                    state['gv'] = torch.sub(grad, grad_norm * torch.sum(g_grad_norm * g_s_grad_norm) * g_grad_norm)
                else:
                    # `gv` is absent when this parameter had no gradient at the last perturbation step; nothing
                    # can be re-used for it then.
                    gv = state.get('gv')
                    if gv is not None:
                        grad.add_(grad_norm / (gv.norm(p=2) + 1e-8) * gv, alpha=self.alpha)

                # plain re-binding of `.data`, not a copy. NOTE(review): after this the parameter appears to share
                # storage with `state['old_p']`, so the in-place update of the base optimizer also moves `old_p`
                # and repeating this assignment on non-perturbation steps leaves the weights untouched — confirm.
                if 'old_p' in state:
                    p.data = state['old_p']

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        r"""Run :meth:`first_step`, re-evaluate ``closure`` with gradients enabled, then run :meth:`second_step`.

        :param closure: CLOSURE. re-computes the loss and its gradients; required.
        :raises NoClosureError: if ``closure`` is None.
        """
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step(zero_grad=True)

        with torch.enable_grad():
            closure()

        self.second_step()

    def load_state_dict(self, state_dict: Dict):
        r"""Load the state and re-share the parameter groups with the base optimizer."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

MADGRAD

Bases: BaseOptimizer

A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization (slightly modified).

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
eps float

float. term added to the denominator to improve numerical stability.

1e-06
weight_decay float

float. weight decay (L2 penalty). MADGRAD optimizer requires less weight decay than other methods, often as little as zero. On sparse problems both weight_decay and momentum should be set to 0.

0.0
weight_decouple bool

bool. Apply AdamW-style decoupled weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/madgrad.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class MADGRAD(BaseOptimizer):
    r"""A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization (slightly modified).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum value. must be 0 when the gradients are sparse.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param weight_decay: float. weight decay (L2 penalty).
        MADGRAD optimizer requires less weight decay than other methods, often as little as zero.
        On sparse problems both weight_decay and momentum should be set to 0.
    :param weight_decouple: bool. Apply AdamW style decoupled weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'momentum': momentum,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MADGRAD'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Validate and allocate the state of every parameter in ``group`` that currently has a gradient.

        :raises NoSparseGradientError: for a sparse gradient combined with momentum > 0 or with coupled weight decay.
        :raises NoComplexParameterError: for a complex parameter.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if group['momentum'] > 0.0 and grad.is_sparse:
                raise NoSparseGradientError(str(self), note='momentum > 0.0')

            if group['weight_decay'] > 0.0 and not group['weight_decouple'] and grad.is_sparse:
                raise NoSparseGradientError(str(self), note='weight_decay')

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            # accumulators: weighted sum of squared gradients and weighted sum of gradients.
            state['grad_sum_sq'] = torch.zeros_like(p)
            state['s'] = torch.zeros_like(p)

            # with momentum the starting point is stored once; without it, it is rebuilt every step.
            if group['momentum'] > 0.0:
                state['x0'] = p.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure was given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # one step counter shared by all groups, kept in the optimizer state under the key 'k'.
        if 'k' not in self.state:
            self.state['k'] = torch.tensor([0], dtype=torch.long, requires_grad=False)

        for group in self.param_groups:
            # per-parameter state is only created on the very first step.
            if self.state['k'] == 0:
                self.init_group(group)

            weight_decay, momentum, eps = group['weight_decay'], group['momentum'], group['eps']
            lr: float = group['lr'] + eps

            # weight of this step's gradient in the accumulators: lr * sqrt(k + 1).
            _lambda = lr * math.pow(self.state['k'] + 1, 0.5)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_sum_sq, s = state['grad_sum_sq'], state['s']
                # coupled (L2) weight decay is folded into the gradient.
                if weight_decay > 0.0 and not group['weight_decouple']:
                    grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()

                    # restrict parameter and accumulators to the entries present in the sparse gradient.
                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # rebuild the starting point from the accumulators as they were before this step.
                    rms_masked_values = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_values = p_masked._values().addcdiv(s_masked._values(), rms_masked_values, value=1)

                    # both the dense accumulator and its masked copy receive the new squared gradient.
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=_lambda)
                    grad_sum_sq_masked.add_(grad_sq, alpha=_lambda)

                    rms_masked_values = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
                    if eps == 0.0:
                        # an infinite denominator turns the division below into a zero update.
                        rms_masked_values[rms_masked_values == 0] = float('inf')

                    s.add_(grad, alpha=_lambda)
                    s_masked._values().add_(grad._values(), alpha=_lambda)

                    p_kp1_masked_values = x0_masked_values.addcdiv(s_masked._values(), rms_masked_values, value=-1)

                    # p_masked becomes (old - new); subtracting it from p writes the new values into p.
                    p_masked._values().add_(p_kp1_masked_values, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0.0:
                        # rebuild the starting point from the accumulators as they were before this step.
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    grad_sum_sq.addcmul_(grad, grad, value=_lambda)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    if eps == 0.0:
                        # an infinite denominator turns the division below into a zero update.
                        rms[rms == 0] = float('inf')

                    s.add_(grad, alpha=_lambda)

                    # decoupled decay uses the parameter values from before this step's update.
                    if weight_decay > 0.0 and group['weight_decouple']:
                        p_old = p.clone()

                    if momentum == 0.0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        # move p towards z = x0 - s / rms: p = momentum * p + (1 - momentum) * z.
                        z = x0.addcdiv(s, rms, value=-1)
                        p.mul_(momentum).add_(z, alpha=1.0 - momentum)

                    if weight_decay > 0.0 and group['weight_decouple']:
                        p.add_(p_old, alpha=-lr * weight_decay)

        self.state['k'].add_(1)

        return loss

MARS

Bases: BaseOptimizer

Unleashing the Power of Variance Reduction for Training Large Models.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.003
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.95, 0.99)
gamma float

float. the scaling parameter that controls the strength of gradient correction.

0.025
mars_type MARS_TYPE

MARS TYPE. type of MARS. adamw, lion, shampoo are supported.

'adamw'
optimize_1d bool

bool. whether MARS should optimize 1D parameters.

False
lr_1d float

float. learning rate for AdamW when optimize_1d is set to False.

0.003
betas_1d BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace for 1d.

(0.9, 0.95)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decay_1d float

float. weight decay for 1d.

0.1
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/mars.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
class MARS(BaseOptimizer):
    r"""Unleashing the Power of Variance Reduction for Training Large Models.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the corrected gradient and its square.
    :param gamma: float. the scaling parameter that controls the strength of gradient correction.
    :param mars_type: MARS TYPE. type of MARS. `adamw`, `lion`, `shampoo` are supported.
    :param optimize_1d: bool. whether MARS should optimize 1D parameters.
    :param lr_1d: float. learning rate for AdamW when optimize_1d is set to False.
    :param betas_1d: BETAS. coefficients used for computing running averages of the gradient and its square for 1d.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decay_1d: float. weight decay for 1d.
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param cautious: bool. whether to pass the update through `apply_cautious`.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 3e-3,
        betas: BETAS = (0.95, 0.99),
        gamma: float = 0.025,
        mars_type: MARS_TYPE = 'adamw',
        optimize_1d: bool = False,
        lr_1d: float = 3e-3,
        betas_1d: BETAS = (0.9, 0.95),
        weight_decay: float = 0.0,
        weight_decay_1d: float = 1e-1,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        cautious: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(lr_1d)
        self.validate_betas(betas)
        self.validate_betas(betas_1d)
        self.validate_options(mars_type, 'mars_type', ['adamw', 'lion', 'shampoo'])
        self.validate_non_negative(gamma, 'gamma')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(weight_decay_1d, 'weight_decay_1d')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # NOTE(review): `weight_decay_1d` is validated above but is neither stored in `defaults` nor read anywhere
        # in this class, so every parameter is decayed with `weight_decay` — confirm whether that is intended.
        defaults: DEFAULTS = {
            'lr': lr,
            'lr_1d': lr_1d,
            # ratio applied to the group's `lr` to obtain the step size of the plain 1-D branch.
            'lr_1d_factor': lr_1d / lr,
            'betas': betas,
            'betas_1d': betas_1d,
            'mars_type': mars_type,
            'gamma': gamma,
            'optimize_1d': optimize_1d,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'cautious': cautious,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MARS'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state of every parameter in ``group`` that has a gradient and no state yet.

        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['last_grad'] = torch.zeros_like(p)

                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    def optimize_mixed(
        self,
        grad: torch.Tensor,
        last_grad: torch.Tensor,
        exp_avg: torch.Tensor,
        exp_avg_sq: torch.Tensor,
        max_exp_avg_sq: Optional[torch.Tensor],
        betas: BETAS,
        gamma: float,
        mars_type: MARS_TYPE,
        is_grad_2d: bool,
        step: int,
        ams_bound: bool,
        cautious: bool,
        eps: float,
    ) -> torch.Tensor:
        r"""Compute the update from the corrected gradient; ``exp_avg`` and ``exp_avg_sq`` are updated in-place.

        The corrected gradient is ``grad + gamma * beta1 / (1 - beta1) * (grad - last_grad)``, rescaled to unit norm
        when its norm exceeds 1. How the update is derived from its running average depends on ``mars_type``.

        :return: the update direction (not yet scaled by the learning rate).
        """
        beta1, beta2 = betas

        c_t = (grad - last_grad).mul_(gamma * (beta1 / (1.0 - beta1))).add_(grad)
        c_t_norm = torch.norm(c_t)
        if c_t_norm > 1.0:
            c_t.div_(c_t_norm)

        exp_avg.mul_(beta1).add_(c_t, alpha=1.0 - beta1)

        # work on a copy so the stored running average is not modified by the steps below.
        update = exp_avg.clone()
        if cautious:
            self.apply_cautious(update, grad)

        # 'shampoo' falls back to the adamw-style update for tensors with fewer than two dimensions.
        if mars_type == 'adamw' or (mars_type == 'shampoo' and not is_grad_2d):
            exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1.0 - beta2)

            bias_correction1: float = self.debias(beta1, step)
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, step))

            de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
            de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

            return update.div_(de_nom)

        if mars_type == 'lion':
            return update.sign_()

        factor: float = math.sqrt(max(1.0, grad.size(0) / grad.size(1)))

        # flatten everything after the first dimension before handing the matrix to the Newton-Schulz helper.
        update = update.view(update.size(0), -1)

        return zero_power_via_newton_schulz_5(update.mul_(1.0 / (1.0 - beta1)), eps=eps).mul_(factor).view_as(grad)

    def optimize_1d(
        self,
        grad: torch.Tensor,
        exp_avg: torch.Tensor,
        exp_avg_sq: torch.Tensor,
        max_exp_avg_sq: Optional[torch.Tensor],
        betas: BETAS,
        step: int,
        ams_bound: bool,
        cautious: bool,
        eps: float,
    ) -> torch.Tensor:
        r"""Compute the update from the raw gradient; ``exp_avg`` and ``exp_avg_sq`` are updated in-place.

        :return: the update direction (not yet scaled by the learning rate).
        """
        beta1, beta2 = betas

        bias_correction1: float = self.debias(beta1, step)
        bias_correction2_sq: float = math.sqrt(self.debias(beta2, step))

        exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

        # work on a copy so the stored running average is not modified by the steps below.
        update = exp_avg.clone()

        if cautious:
            self.apply_cautious(update, grad)

        de_nom = self.apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps)
        de_nom.div_(bias_correction2_sq).mul_(bias_correction1)

        return update.div_(de_nom)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure was given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, last_grad = state['exp_avg'], state['exp_avg_sq'], state['last_grad']

                p, grad, exp_avg, exp_avg_sq, last_grad = self.view_as_real(p, grad, exp_avg, exp_avg_sq, last_grad)

                # tensors with two or more dimensions always take the gradient-corrected branch; the others only
                # do so when `optimize_1d` is set.
                is_grad_2d: bool = grad.ndim >= 2
                step_size: float = (
                    group['lr'] if group['optimize_1d'] or is_grad_2d else group['lr'] * group['lr_1d_factor']
                )

                if group['optimize_1d'] or is_grad_2d:
                    update = self.optimize_mixed(
                        grad,
                        last_grad,
                        exp_avg,
                        exp_avg_sq,
                        state.get('max_exp_avg_sq', None),
                        group['betas'],
                        group['gamma'],
                        group['mars_type'],
                        is_grad_2d,
                        group['step'],
                        group['ams_bound'],
                        group.get('cautious'),
                        group['eps'],
                    )
                else:
                    update = self.optimize_1d(
                        grad,
                        exp_avg,
                        exp_avg_sq,
                        state.get('max_exp_avg_sq', None),
                        group['betas_1d'],
                        group['step'],
                        group['ams_bound'],
                        group.get('cautious'),
                        group['eps'],
                    )

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=step_size,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.add_(update, alpha=-step_size)

                # remember this step's gradient by copying it into the persistent state buffer. Storing the
                # gradient tensor itself would alias `p.grad`: whenever the gradient buffer is re-used in-place
                # (e.g. `zero_grad(set_to_none=False)`), `last_grad` would always equal the current gradient and
                # the correction term `grad - last_grad` would silently vanish. This relies on `view_as_real`
                # handing back the state tensor or a view of it, as the in-place `exp_avg` updates already do.
                last_grad.copy_(grad)

        return loss

MSVAG

Bases: BaseOptimizer

Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
beta float

float. Moving average (momentum) constant (scalar tensor or float value).

0.9
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/msvag.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class MSVAG(BaseOptimizer):
    r"""Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. Moving average (momentum) constant (scalar tensor or float value).
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        beta: float = 0.9,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        self.maximize = maximize

        defaults: DEFAULTS = {'lr': lr, 'beta': beta}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MSVAG'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the state of every parameter in ``group`` that has a gradient and no state yet.

        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                # allocated for state-layout compatibility; `step` does not read this entry.
                state['s'] = torch.zeros_like(p)

    @staticmethod
    def get_rho(beta_power: float, beta: float) -> float:
        r"""Get rho, the variance factor of the bias-corrected moving average, capped at 0.9999.

        For weights ``(1 - beta) * beta ** j`` normalized by ``1 - beta_power`` the sum of squared weights is
        ``(1 - beta) ** 2 * (1 - beta_power ** 2) / ((1 - beta ** 2) * (1 - beta_power) ** 2)``.

        :param beta_power: float. ``beta`` raised to the number of steps taken.
        :param beta: float. moving average constant.
        """
        numerator: float = (1.0 - beta_power ** 2) * (1.0 - beta) ** 2  # fmt: skip
        # geometric series of the squared weights: the ratio is beta ** 2, hence (1 - beta ** 2) and not (1 - beta).
        denominator: float = (1.0 - beta ** 2) * (1.0 - beta_power) ** 2  # fmt: skip
        return min(numerator / denominator, 0.9999)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure``, or None when no closure was given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta: float = group['beta']
            beta_power: float = beta ** group['step']

            # a zero-initialised moving average carries total weight 1 - beta ** step, so that is the factor to
            # divide by. Dividing by beta ** step would blow the estimates up as the step count grows.
            bias_correction: float = 1.0 - beta_power

            # depends only on `beta` and the step count, so it is computed once per group.
            rho: float = self.get_rho(beta_power, beta)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                exp_avg.mul_(beta).add_(grad, alpha=1.0 - beta)
                exp_avg_sq.mul_(beta).addcmul_(grad, grad, value=1.0 - beta)

                # bias-corrected first and second moment estimates.
                m = exp_avg.div(bias_correction)
                v = exp_avg_sq.div(bias_correction)

                m_p2 = m.pow(2)
                # variance estimate of the gradient.
                s = (v - m_p2).div_(1.0 - rho)

                # element-wise damping factor in [0, 1]; 0 / 0 entries are mapped to 0.
                factor = m_p2.div(m_p2 + rho * s)
                torch.nan_to_num(factor, nan=0.0, out=factor)
                factor.clamp_(0.0, 1.0)

                p.add_(m * factor, alpha=-group['lr'])

        return loss

get_rho(beta_power, beta) staticmethod

Get rho.

Source code in pytorch_optimizer/optimizer/msvag.py
53
54
55
56
57
58
@staticmethod
def get_rho(beta_power: float, beta: float) -> float:
    r"""Get rho, the variance factor of the bias-corrected moving average, capped at 0.9999.

    For weights ``(1 - beta) * beta ** j`` normalized by ``1 - beta_power`` the sum of squared weights is
    ``(1 - beta) ** 2 * (1 - beta_power ** 2) / ((1 - beta ** 2) * (1 - beta_power) ** 2)``.

    :param beta_power: float. ``beta`` raised to the number of steps taken.
    :param beta: float. moving average constant.
    """
    numerator: float = (1.0 - beta_power ** 2) * (1.0 - beta) ** 2  # fmt: skip
    # geometric series of the squared weights: the ratio is beta ** 2, hence (1 - beta ** 2) and not (1 - beta).
    denominator: float = (1.0 - beta ** 2) * (1.0 - beta_power) ** 2  # fmt: skip
    return min(numerator / denominator, 0.9999)

Muon

Bases: BaseOptimizer

Momentum Orthogonalized by Newton-schulz.

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and scalar or vector parameters should be optimized using AdamW.

Some warnings: - We believe this optimizer is unlikely to work well for training with small batch size. - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. the parameters to be optimized by Muon.

required
lr float

float. learning rate.

0.02
momentum float

float. the momentum used by the internal SGD.

0.95
weight_decay float

float. weight decay (L2 penalty).

0.01
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
betas BETAS

The betas for the internal AdamW.

(0.9, 0.95)
nesterov bool

bool. whether to use nesterov momentum.

True
ns_steps int

int. the number of Newton-Schulz iterations to run. (5 is probably always enough)

5
use_adjusted_lr bool

bool. whether to use adjusted learning rate, which is from the Moonlight. reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py

False
adamw_params Optional[PARAMETERS]

Optional[PARAMETERS] The parameters to be optimized by AdamW. Any parameters in muon_params which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. It'd be better to create AdamW optimizer instead of using this.

None
adamw_lr float

float. The learning rate for the internal AdamW.

0.0003
adamw_wd float

float. The weight decay for the internal AdamW.

0.0
adamw_eps float

float. The epsilon for the internal AdamW.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/muon.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
class Muon(BaseOptimizer):
    r"""Momentum Orthogonalized by Newton-schulz.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, in which
    each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each
    update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

    Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and
    scalar or vector parameters should be optimized using AdamW.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for fine-tuning pretrained models, but we haven't tested this.

    :param params: PARAMETERS. the parameters to be optimized by Muon.
    :param lr: float. learning rate.
    :param momentum: float. the momentum used by the internal SGD.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param betas: BETAS. the betas for the internal AdamW.
    :param nesterov: bool. whether to use nesterov momentum.
    :param ns_steps: int. the number of Newton-Schulz iterations to run. (5 is probably always enough)
    :param use_adjusted_lr: bool. whether to use adjusted learning rate, which is from the Moonlight.
        reference: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
    :param adamw_params: Optional[PARAMETERS] The parameters to be optimized by AdamW. Any parameters in `muon_params`
        which are {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. It'd be
        better to create AdamW optimizer instead of using this.
    :param adamw_lr: float. The learning rate for the internal AdamW.
    :param adamw_wd: float. The weight decay for the internal AdamW.
    :param adamw_eps: float. The epsilon for the internal AdamW.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 2e-2,
        momentum: float = 0.95,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        betas: BETAS = (0.9, 0.95),
        nesterov: bool = True,
        ns_steps: int = 5,
        use_adjusted_lr: bool = False,
        adamw_params: Optional[PARAMETERS] = None,
        adamw_lr: float = 3e-4,
        adamw_wd: float = 0.0,
        adamw_eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(adamw_lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_positive(ns_steps, 'ns_steps')
        self.validate_betas(betas)
        self.validate_non_negative(adamw_wd, 'adamw_wd')
        self.validate_non_negative(adamw_eps, 'adamw_eps')

        # ``get_parameters`` hands back fresh lists, so merging the AdamW parameters into ``params`` below never
        # mutates a list owned by the caller.
        params = self.get_parameters(params)
        adamw_params = self.get_parameters(adamw_params) if adamw_params is not None else []
        params.extend(adamw_params)

        # the distributed layout is read from the environment, falling back to a single process when unset.
        self.world_size: int = int(os.environ.get('WORLD_SIZE', '1'))
        self.rank: int = int(os.environ.get('RANK', '0'))
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'nesterov': nesterov,
            'ns_steps': ns_steps,
            'use_adjusted_lr': use_adjusted_lr,
            'adamw_lr': adamw_lr,
            # ``step`` derives the AdamW learning rate as ``adamw_lr_ratio * group['lr']``.
            'adamw_lr_ratio': adamw_lr / lr,
            'adamw_betas': betas,
            'adamw_wd': adamw_wd,
            'adamw_eps': adamw_eps,
        }

        super().__init__(params, defaults)

        self.set_muon_state(params, adamw_params)

    def __str__(self) -> str:
        return 'Muon'

    @staticmethod
    def get_parameters(params: PARAMETERS) -> List[torch.Tensor]:
        r"""Flatten ``params`` into a new list of parameters.

        :param params: PARAMETERS. a list of tensors, or an iterable whose items are tensors and/or dicts holding a
            ``'params'`` entry.
        :return: List[torch.Tensor]. a newly created list; the input container is never returned or modified.
        """
        # guard against an empty list before peeking at the first element, and copy so that callers can extend the
        # result without touching the list that was passed in.
        if isinstance(params, list) and params and isinstance(params[0], torch.Tensor):
            return list(params)

        new_params = []
        for group in params:
            if isinstance(group, dict) and 'params' in group:
                new_params.extend(group['params'])
            else:
                new_params.append(group)

        return new_params

    def set_muon_state(self, params: PARAMETERS, adamw_params: PARAMETERS) -> None:
        r"""Set use_muon flag.

        Parameters with at least two dimensions are flagged for the Muon update; every parameter listed in
        ``adamw_params`` is then flagged for the AdamW update regardless of its shape.

        :param params: PARAMETERS. all parameters handled by the optimizer.
        :param adamw_params: PARAMETERS. the parameters that must be updated by the internal AdamW.
        """
        for p in params:
            self.state[p]['use_muon'] = p.ndim >= 2

        for p in adamw_params:
            self.state[p]['use_muon'] = False

    def init_group(self, group: GROUP, **kwargs) -> None:
        pass

    @staticmethod
    def get_adjusted_lr(lr: float, param_shape: Tuple[float, ...], use_adjusted_lr: bool = False) -> float:
        r"""Get the adjusted learning rate.

        The first entry of ``param_shape`` is taken as the output size and the product of the remaining entries as the
        input size.

        :param lr: float. base learning rate.
        :param param_shape: Tuple[float, ...]. shape of the parameter.
        :param use_adjusted_lr: bool. when True the ratio is ``sqrt(max(1, output / input))``, otherwise it is
            ``0.2 * sqrt(max(output, input))``.
        :return: float. ``lr`` multiplied by the selected ratio.
        """
        # NOTE(review): the Moonlight reference appears to use the ``0.2 * sqrt(max(A, B))`` rule, which is the
        # ``use_adjusted_lr=False`` branch here -- confirm that the branch order is intended.
        output_shape, *input_shape = param_shape
        input_shape = math.prod(input_shape)

        ratio: float = (
            math.pow(max(1.0, output_shape / input_shape), 0.5)
            if use_adjusted_lr
            else 0.2 * math.sqrt(max(output_shape, input_shape))
        )

        return lr * ratio

    @torch.no_grad()
    def _muon_update(self, group: GROUP) -> None:
        r"""Apply the orthogonalized-momentum update to the Muon-flagged parameters of ``group``.

        :param group: GROUP. the parameter group being stepped.
        """
        params = []
        for p in group['params']:
            if p.grad is not None and self.state[p]['use_muon']:
                if p.grad.is_sparse:
                    raise NoSparseGradientError(str(self))
                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))
                params.append(p)

        if not params:
            return

        momentum = group['momentum']

        # one flat buffer holds the update of every Muon parameter; each parameter owns a contiguous slice.
        total_params: int = sum(p.numel() for p in params)
        updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

        curr_idx: int = 0
        for i, p in enumerate(params):
            numel: int = p.numel()

            # parameters are dealt round-robin to the ranks; slices owned by other ranks stay zero here and are
            # filled in by the all-reduce below.
            if i % self.world_size == self.rank:
                grad = p.grad
                if grad.ndim > 2:
                    grad = grad.view(grad.size(0), -1)

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(grad)

                buf = state['momentum_buffer']
                buf.lerp_(grad, weight=1.0 - momentum)

                grad = grad.lerp_(buf, momentum) if group['nesterov'] else buf

                grad = zero_power_via_newton_schulz_5(grad, num_steps=group['ns_steps']).flatten()

                updates_flat[curr_idx:curr_idx + numel] = grad  # fmt: skip

            # the offset must advance for every parameter, processed or not; otherwise all updates would be written
            # to the start of the buffer and the read-back loop below would pick up the wrong slices.
            curr_idx += numel

        if self.world_size > 1:  # pragma: no cover
            all_reduce(updates_flat, op=ReduceOp.SUM)

        curr_idx = 0
        for p in params:
            g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p)  # fmt: skip

            self.apply_weight_decay(
                p,
                grad=g,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                weight_decouple=group['weight_decouple'],
                fixed_decay=False,
            )

            lr: float = self.get_adjusted_lr(group['lr'], p.size(), group['use_adjusted_lr'])

            p.add_(g, alpha=-lr)
            curr_idx += p.numel()

    @torch.no_grad()
    def _adamw_update(self, group: GROUP) -> None:
        r"""Apply the internal AdamW update to the parameters of ``group`` that are not flagged for Muon.

        :param group: GROUP. the parameter group being stepped; ``group['step']`` must already be set.
        """
        params = [p for p in group['params'] if p.grad is not None and not self.state[p]['use_muon']]

        lr: float = group['adamw_lr_ratio'] * group['lr']
        beta1, beta2 = group['adamw_betas']

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2: float = self.debias(beta2, group['step'])
        scale: float = bias_correction1 / bias_correction2 ** 0.5  # fmt: skip
        step_size: float = lr / scale

        for p in params:
            grad = p.grad
            state = self.state[p]
            if 'moment1' not in state:
                state['moment1'] = torch.zeros_like(grad)
                state['moment2'] = torch.zeros_like(grad)

            buf1, buf2 = state['moment1'], state['moment2']
            buf1.lerp_(grad, weight=1.0 - beta1)
            buf2.lerp_(grad.square(), weight=1.0 - beta2)

            update = buf1 / buf2.sqrt().add_(group['adamw_eps'])

            self.apply_weight_decay(
                p,
                grad,
                lr=lr,
                weight_decay=group['adamw_wd'],
                weight_decouple=True,
                fixed_decay=False,
            )

            p.add_(update, alpha=-step_size)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        For every parameter group the Muon update is applied to the parameters flagged with ``use_muon`` and the
        internal AdamW update to the remaining ones. The AdamW update runs even when the group holds no Muon
        parameter with a gradient.

        :param closure: CLOSURE. optional callable that is run with gradients enabled and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            self._muon_update(group)
            self._adamw_update(group)

        return loss

get_adjusted_lr(lr, param_shape, use_adjusted_lr=False) staticmethod

Get the adjusted learning rate.

Source code in pytorch_optimizer/optimizer/muon.py
129
130
131
132
133
134
135
136
137
138
139
140
141
@staticmethod
def get_adjusted_lr(lr: float, param_shape: Tuple[float, ...], use_adjusted_lr: bool = False) -> float:
    r"""Get the adjust learning rate."""
    output_shape, *input_shape = param_shape
    input_shape = math.prod(input_shape)

    ratio: float = (
        math.pow(max(1.0, output_shape / input_shape), 0.5)
        if use_adjusted_lr
        else 0.2 * math.sqrt(max(output_shape, input_shape))
    )

    return lr * ratio

set_muon_state(params, adamw_params)

Set use_muon flag.

Source code in pytorch_optimizer/optimizer/muon.py
118
119
120
121
122
123
124
def set_muon_state(self, params: PARAMETERS, adamw_params: PARAMETERS) -> None:
    r"""Record in the optimizer state which update rule each parameter uses.

    Parameters with at least two dimensions are flagged ``use_muon=True``; every parameter listed in
    ``adamw_params`` is then flagged ``use_muon=False`` regardless of its shape.

    :param params: PARAMETERS. all parameters handled by the optimizer.
    :param adamw_params: PARAMETERS. the parameters that must be updated by the internal AdamW.
    """
    for tensor in params:
        is_matrix_like = tensor.ndim >= 2
        self.state[tensor]['use_muon'] = is_matrix_like

    # applied second so that an explicit AdamW assignment overrides the shape-based flag.
    for tensor in adamw_params:
        self.state[tensor]['use_muon'] = False

Nero

Bases: BaseOptimizer

Learning by Turning: Neural Architecture Aware Optimisation.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
beta float

float. coefficients used for computing running averages of gradient and the squared hessian trace.

0.999
constraints bool

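bool. whether to constrain parameters with more than one dimension by subtracting their neuron mean and dividing by their neuron norm, at initialization and after every update.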

True
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/nero.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class Nero(BaseOptimizer):
    """Learning by Turning: Neural Architecture Aware Optimisation.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. decay rate of the running average of the squared neuron-wise gradient norm.
    :param constraints: bool. whether to subtract the neuron mean from, and divide by the neuron norm of, every
        parameter with more than one dimension, both at initialization and after each update.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 0.01,
        beta: float = 0.999,
        constraints: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        super().__init__(params, {'lr': lr, 'beta': beta, 'constraints': constraints, 'eps': eps})

    def __str__(self) -> str:
        return 'Nero'

    @staticmethod
    def _apply_constraints(p: torch.Tensor, eps: float) -> None:
        """Subtract the neuron mean from ``p`` in place, then divide it by its neuron norm plus ``eps``."""
        p.sub_(neuron_mean(p))
        p.div_(neuron_norm(p).add_(eps))

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create the state of every parameter in ``group`` that has a gradient and no state yet.

        :param group: GROUP. the parameter group to initialize.
        :raises NoSparseGradientError: when a gradient is sparse.
        :raises NoComplexParameterError: when a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            if p.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]
            if len(state) != 0:
                continue

            if group['constraints'] and p.dim() > 1:
                self._apply_constraints(p, group['eps'])

            state['exp_avg_sq'] = torch.zeros_like(neuron_norm(p))

            # the per-parameter step scale is the mean neuron norm, replaced by 0.01 when that mean is zero.
            scale = neuron_norm(p).mean()
            state['scale'] = scale
            if scale == 0.0:
                state['scale'] = 0.01

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that is run with gradients enabled and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta: float = group['beta']
            bias_correction: float = self.debias(beta, group['step'])

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_norm = neuron_norm(grad)

                # running average of the squared neuron-wise gradient norm.
                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta).addcmul_(grad_norm, grad_norm, value=1.0 - beta)

                denominator = (exp_avg_sq / bias_correction).sqrt_().add_(group['eps'])
                grad_normed = grad / denominator
                torch.nan_to_num(grad_normed, nan=0.0, out=grad_normed)

                p.add_(grad_normed, alpha=-group['lr'] * state['scale'])

                if group['constraints'] and p.dim() > 1:
                    self._apply_constraints(p, group['eps'])

        return loss

NovoGrad

Bases: BaseOptimizer

Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.95, 0.98)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
grad_averaging bool

bool. multiply ck (1 - momentum).

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/novograd.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class NovoGrad(BaseOptimizer):
    r"""Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. decay rates; the first is applied to the moments buffer, the second to the running average
        of the per-parameter sum of squared gradients.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param grad_averaging: bool. multiply the normalized gradient by (1 - beta1) before accumulating it.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.95, 0.98),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        grad_averaging: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'grad_averaging': grad_averaging,
            'eps': eps,
        }
        # extra keyword arguments become group defaults too and take precedence over the named ones.
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'NovoGrad'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state of every parameter in ``group`` that has a gradient and no state yet.

        :param group: GROUP. the parameter group to initialize.
        :raises NoSparseGradientError: when a gradient is sparse.
        :raises NoComplexParameterError: when a parameter is complex.
        """
        for p in group['params']:
            grad = p.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]
            if len(state) != 0:
                continue

            # both buffers are seeded from the current gradient: the moments with the norm-scaled gradient plus the
            # weight-decay term, the running average with the sum of squared gradient entries.
            grad_p2 = grad.pow(2).sum()
            normalized_grad = grad.div(grad_p2.sqrt().add_(group['eps']))

            state['moments'] = normalized_grad + group['weight_decay'] * p
            state['grads_ema'] = grad_p2

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that is run with gradients enabled and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sqrt: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = self.apply_adam_debias(
                group.get('adam_debias', False),
                step_size=group['lr'] * bias_correction2_sqrt,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]
                grads_ema = state['grads_ema']
                moments = state['moments']

                grad_sq_sum = grad.pow(2).sum()
                grads_ema.mul_(beta2).add_(grad_sq_sum, alpha=1.0 - beta2)

                # the gradient is normalized in place by the square root of the running average.
                de_nom = grads_ema.sqrt().add_(group['eps'])
                grad.div_(de_nom)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['grad_averaging']:
                    grad.mul_(1.0 - beta1)

                moments.mul_(beta1).add_(grad)

                p.add_(moments, alpha=-step_size)

        return loss

OrthoGrad

Bases: BaseOptimizer

Grokking at the Edge of Numerical Stability.

A wrapper optimizer that projects gradients to be orthogonal to the current parameters before performing an update.

Parameters:

Name Type Description Default
optimizer OPTIMIZER_INSTANCE_OR_CLASS

OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.

required
Source code in pytorch_optimizer/optimizer/orthograd.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class OrthoGrad(BaseOptimizer):
    r"""Grokking at the Edge of Numerical Stability.

    A wrapper optimizer that projects gradients to be orthogonal to the current parameters before performing an update.
    Parameter groups, state and (de)serialization are all delegated to the wrapped optimizer.

    :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.
    """

    def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
        # NOTE(review): the hook registries are created by hand, presumably because the parent constructor is not
        # called here -- confirm against torch's Optimizer.
        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}
        self.eps: float = 1e-30

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self.defaults: DEFAULTS = self.optimizer.defaults

    def __str__(self) -> str:
        return 'OrthoGrad'

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    @property
    def state(self) -> STATE:
        return self.optimizer.state

    def state_dict(self) -> STATE:
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: STATE) -> None:
        self.optimizer.load_state_dict(state_dict)

    @torch.no_grad()
    def zero_grad(self, set_to_none: bool = True) -> None:
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def init_group(self, group: GROUP, **kwargs) -> None:
        pass

    @torch.no_grad()
    def apply_orthogonal_gradients(self, params) -> None:
        r"""Replace each gradient, in place, by its component orthogonal to the parameter.

        The orthogonal component is rescaled to the norm of the original gradient. Parameters without a gradient,
        with a sparse gradient, or of complex dtype are left untouched.

        :param params: iterable of parameters whose ``.grad`` is rewritten.
        """
        for p in params:
            if p.grad is None or p.grad.is_sparse or torch.is_complex(p):
                continue

            w = p.view(-1)
            g = p.grad.view(-1)

            # projection coefficient <w, g> / (<w, w> + eps).
            w_dot_g = torch.dot(w, g)
            w_dot_w = torch.dot(w, w).add_(self.eps)
            proj = w_dot_g.div_(w_dot_w)

            # float32 copy of the gradient with its component along ``w`` removed.
            g_ortho = g.to(dtype=torch.float32, copy=True)
            g_ortho.sub_(w, alpha=proj)

            # rescale so that the projected gradient has the norm of the original one.
            rescale = g.norm(2).div_(g_ortho.norm(2).add_(self.eps))
            g_ortho.mul_(rescale)

            p.grad.copy_(g_ortho.view_as(p.grad))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Project the gradients of every parameter group, then run the wrapped optimizer's step.

        :param closure: CLOSURE. optional closure, forwarded unchanged to the wrapped optimizer.
        :return: LOSS. whatever the wrapped optimizer's ``step`` returns.
        """
        for group in self.param_groups:
            self.apply_orthogonal_gradients(group['params'])

        return self.optimizer.step(closure)

PAdam

Bases: BaseOptimizer

Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.1
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
partial float

float. partially adaptive parameter.

0.25
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/padam.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class PAdam(BaseOptimizer):
    """Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks.

    Adam-style update whose denominator ``sqrt(exp_avg_sq) + eps`` is raised to the power ``2 * partial``, so
    ``partial = 0.5`` gives the plain Adam denominator and smaller values make the step less adaptive.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the squared gradient.
    :param partial: float. partially adaptive parameter. must lie in (0, 1].
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        partial: float = 0.25,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(partial, 'partial', 0.0, 1.0, range_type='(]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        # optimizer-wide flag (not a per-group hyper-parameter); consumed in `step`
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'partial': partial,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PAdam'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Allocate the moment buffers for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter holds a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # state buffers are only allocated on the group's first step
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            # both bias corrections are folded into a single scalar step size
            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # honour `maximize`; the base-class helper is expected to flip the gradient sign in place
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # NOTE(review): presumably exposes complex tensors as real views - confirm against BaseOptimizer
                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                # partially adaptive denominator: (sqrt(v) + eps) ** (2 * partial)
                p.addcdiv_(exp_avg, de_nom ** (group['partial'] * 2), value=-step_size)

        return loss

PCGrad

Bases: BaseOptimizer

Gradient Surgery for Multi-Task Learning.

Parameters:

Name Type Description Default
optimizer Optimizer

Optimizer: optimizer instance.

required
reduction str

str. reduction method.

'mean'
Source code in pytorch_optimizer/optimizer/pcgrad.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class PCGrad(BaseOptimizer):
    r"""Gradient Surgery for Multi-Task Learning.

    Wraps an existing optimizer: per-objective gradients are collected, conflicting ones are projected, and the
    merged gradient is written back to ``p.grad`` so the wrapped optimizer can step on it.

    :param optimizer: Optimizer: optimizer instance.
    :param reduction: str. reduction method. one of 'mean' or 'sum'.
    """

    def __init__(self, optimizer: Optimizer, reduction: str = 'mean'):
        self.validate_options(reduction, 'reduction', ['mean', 'sum'])

        self.reduction = reduction
        self.optimizer = optimizer

    @torch.no_grad()
    def init_group(self):
        r"""Clear the gradients held by the wrapped optimizer."""
        self.zero_grad()

    def zero_grad(self):
        r"""Reset the wrapped optimizer's gradients to None."""
        return self.optimizer.zero_grad(set_to_none=True)

    def step(self):
        r"""Run one step of the wrapped optimizer."""
        return self.optimizer.step()

    def set_grad(self, grads: List[torch.Tensor]) -> None:
        r"""Assign ``grads`` to ``p.grad``, walking the parameters in param-group order.

        :param grads: List[torch.Tensor]. one gradient per parameter.
        """
        ordered_params = (p for group in self.optimizer.param_groups for p in group['params'])
        for position, p in enumerate(ordered_params):
            p.grad = grads[position]

    def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
        r"""Get the gradient of the parameters of the network with specific objective.

        :return: per-parameter gradients (zeros where missing), their shapes, and 1/0 masks marking which
            parameters actually had a gradient.
        """
        collected, sizes, masks = [], [], []

        ordered_params = (p for group in self.optimizer.param_groups for p in group['params'])
        for p in ordered_params:
            if p.grad is not None:
                sizes.append(p.grad.shape)
                collected.append(p.grad.clone())
                masks.append(torch.ones_like(p, device=p.device))
            else:
                # a parameter untouched by this objective contributes a zero gradient and a zero mask
                sizes.append(p.shape)
                collected.append(torch.zeros_like(p, device=p.device))
                masks.append(torch.zeros_like(p, device=p.device))

        return collected, sizes, masks

    def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
        r"""Pack the gradient of the parameters of the network for each objective.

        :param objectives: Iterable. objectives to differentiate; ``backward(retain_graph=True)`` is called on each.
        :return: flattened gradients, parameter shapes and flattened has-gradient masks, one entry per objective.
        """
        packed_grads, packed_shapes, packed_masks = [], [], []

        for objective in objectives:
            # start from a clean slate so `retrieve_grad` only sees this objective's gradients
            self.optimizer.zero_grad(set_to_none=True)
            objective.backward(retain_graph=True)

            raw_grads, raw_shapes, raw_masks = self.retrieve_grad()

            packed_grads.append(flatten_grad(raw_grads))
            packed_masks.append(flatten_grad(raw_masks))
            packed_shapes.append(raw_shapes)

        return packed_grads, packed_shapes, packed_masks

    def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
        r"""Project conflicting.

        Note that ``grads`` is shuffled in place.

        :param grads: a list of the gradient of the parameters.
        :param has_grads: a list of mask represent whether the parameter has gradient.
        :return: torch.Tensor. merged gradients.
        """
        # an entry is "shared" only when every objective produced a gradient for it
        shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()

        projected: List[torch.Tensor] = deepcopy(grads)
        for g_i in projected:
            random.shuffle(grads)
            for g_j in grads:
                inner: torch.Tensor = torch.dot(g_i, g_j)
                if inner < 0:
                    # remove the component of g_i that points against g_j
                    g_i.sub_(inner * g_j / (g_j.norm() ** 2))

        merged: torch.Tensor = torch.zeros_like(grads[0])

        shared_stack: torch.Tensor = torch.stack([g[shared] for g in projected])
        reduce_shared = shared_stack.mean if self.reduction == 'mean' else shared_stack.sum
        merged[shared] = reduce_shared(dim=0)

        # non-shared entries are always summed, regardless of `reduction`
        not_shared: torch.Tensor = ~shared
        merged[not_shared] = torch.stack([g[not_shared] for g in projected]).sum(dim=0)

        return merged

    def pc_backward(self, objectives: Iterable[nn.Module]) -> None:
        r"""Calculate the gradient of the parameters.

        :param objectives: Iterable[nn.Module]. a list of objectives.
        """
        flat_grads, shapes, masks = self.pack_grad(objectives)

        merged = self.project_conflicting(flat_grads, masks)

        # every objective sees the same parameters, so the first shape list is used for un-flattening
        self.set_grad(un_flatten_grad(merged, shapes[0]))

pack_grad(objectives)

Pack the gradient of the parameters of the network for each objective.

Parameters:

Name Type Description Default
objectives Iterable

Iterable. the objectives to differentiate (e.g. per-task losses); `backward(retain_graph=True)` is called on each one.

required

Returns:

Type Description
Tuple[List[Tensor], List[List[int]], List[Tensor]]

Three lists with one entry per objective: the flattened gradients, the parameter shapes, and the flattened has-gradient masks.

Source code in pytorch_optimizer/optimizer/pcgrad.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
    r"""Pack the gradient of the parameters of the network for each objective.

    :param objectives: Iterable. objectives to differentiate; ``backward(retain_graph=True)`` is called on each.
    :return: flattened gradients, parameter shapes and flattened has-gradient masks, one entry per objective.
    """
    packed_grads, packed_shapes, packed_masks = [], [], []

    for objective in objectives:
        # start from a clean slate so `retrieve_grad` only sees this objective's gradients
        self.optimizer.zero_grad(set_to_none=True)
        objective.backward(retain_graph=True)

        raw_grads, raw_shapes, raw_masks = self.retrieve_grad()

        packed_grads.append(flatten_grad(raw_grads))
        packed_masks.append(flatten_grad(raw_masks))
        packed_shapes.append(raw_shapes)

    return packed_grads, packed_shapes, packed_masks

pc_backward(objectives)

Calculate the gradient of the parameters.

Parameters:

Name Type Description Default
objectives Iterable[Module]

Iterable[nn.Module]. a list of objectives.

required
Source code in pytorch_optimizer/optimizer/pcgrad.py
124
125
126
127
128
129
130
131
132
133
134
def pc_backward(self, objectives: Iterable[nn.Module]) -> None:
    r"""Calculate the gradient of the parameters.

    :param objectives: Iterable[nn.Module]. a list of objectives.
    """
    flat_grads, shapes, masks = self.pack_grad(objectives)

    merged = self.project_conflicting(flat_grads, masks)

    # the first objective's shape list is used to restore the per-parameter layout
    self.set_grad(un_flatten_grad(merged, shapes[0]))

project_conflicting(grads, has_grads)

Project conflicting.

Parameters:

Name Type Description Default
grads List[Tensor]

a list of the gradient of the parameters.

required
has_grads List[Tensor]

a list of masks indicating, for each objective, which parameter entries have a gradient.

required

Returns:

Type Description
Tensor

torch.Tensor. merged gradients.

Source code in pytorch_optimizer/optimizer/pcgrad.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
    r"""Project conflicting.

    Note that ``grads`` is shuffled in place.

    :param grads: a list of the gradient of the parameters.
    :param has_grads: a list of mask represent whether the parameter has gradient.
    :return: torch.Tensor. merged gradients.
    """
    # an entry is "shared" only when every objective produced a gradient for it
    shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()

    projected: List[torch.Tensor] = deepcopy(grads)
    for g_i in projected:
        random.shuffle(grads)
        for g_j in grads:
            inner: torch.Tensor = torch.dot(g_i, g_j)
            if inner < 0:
                # remove the component of g_i that points against g_j
                g_i.sub_(inner * g_j / (g_j.norm() ** 2))

    merged: torch.Tensor = torch.zeros_like(grads[0])

    shared_stack: torch.Tensor = torch.stack([g[shared] for g in projected])
    reduce_shared = shared_stack.mean if self.reduction == 'mean' else shared_stack.sum
    merged[shared] = reduce_shared(dim=0)

    # non-shared entries are always summed, regardless of `reduction`
    not_shared: torch.Tensor = ~shared
    merged[not_shared] = torch.stack([g[not_shared] for g in projected]).sum(dim=0)

    return merged

retrieve_grad()

Get the gradient of the parameters of the network with specific objective.

Source code in pytorch_optimizer/optimizer/pcgrad.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
    r"""Get the gradient of the parameters of the network with specific objective.

    :return: per-parameter gradients (zeros where missing), their shapes, and 1/0 masks marking which
        parameters actually had a gradient.
    """
    collected, sizes, masks = [], [], []

    ordered_params = (p for group in self.optimizer.param_groups for p in group['params'])
    for p in ordered_params:
        if p.grad is not None:
            sizes.append(p.grad.shape)
            collected.append(p.grad.clone())
            masks.append(torch.ones_like(p, device=p.device))
        else:
            # a parameter without a gradient contributes a zero gradient and a zero mask
            sizes.append(p.shape)
            collected.append(torch.zeros_like(p, device=p.device))
            masks.append(torch.zeros_like(p, device=p.device))

    return collected, sizes, masks

PID

Bases: BaseOptimizer

A PID Controller Approach for Stochastic Optimization of Deep Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor.

0.0
dampening float

float. dampening for momentum.

0.0
derivative float

float. D part of the PID.

10.0
integral float

float. I part of the PID.

5.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/pid.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class PID(BaseOptimizer):
    r"""A PID Controller Approach for Stochastic Optimization of Deep Networks.

    When ``momentum`` is positive, the raw gradient is augmented in place with an integral buffer (weighted by
    ``integral``) and a derivative buffer (weighted by ``derivative``) before the parameter update; otherwise the
    update is ``p -= lr * grad``.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param dampening: float. dampening for momentum.
    :param derivative: float. D part of the PID.
    :param integral: float. I part of the PID.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        derivative: float = 10.0,
        integral: float = 5.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(derivative, 'derivative')
        self.validate_non_negative(integral, 'integral')
        self.validate_non_negative(weight_decay, 'weight_decay')

        # optimizer-wide flag rather than a per-group hyper-parameter
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'derivative': derivative,
            'integral': integral,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PID'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the gradient / integral / derivative buffers for parameters that have a gradient.

        Buffers are only created when the group's momentum is positive.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter holds a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            if p.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) != 0 or not group['momentum'] > 0.0:
                continue

            for buffer_name in ('grad_buffer', 'i_buffer', 'd_buffer'):
                state[buffer_name] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # buffers are allocated once, on the group's first step
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # buffers are absent (None) when momentum is disabled
                g_buf, i_buf, d_buf = (state.get(name, None) for name in ('grad_buffer', 'i_buffer', 'd_buffer'))

                param, grad, g_buf, i_buf, d_buf = self.view_as_real(p, grad, g_buf, i_buf, d_buf)

                self.apply_weight_decay(
                    p=param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                momentum = group['momentum']
                if momentum > 0.0:
                    # integral term: decayed running sum of gradients
                    i_buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

                    # derivative term: decayed running difference between consecutive gradients
                    d_buf.mul_(momentum)
                    if group['step'] > 1:
                        d_buf.add_(grad - g_buf, alpha=1.0 - momentum)
                        g_buf.copy_(grad)

                    grad.add_(i_buf, alpha=group['integral'])
                    grad.add_(d_buf, alpha=group['derivative'])

                param.add_(grad, alpha=-group['lr'])

        return loss

PNM

Bases: BaseOptimizer

Positive-Negative Momentum Optimizers.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. beta1 (used squared) is the decay rate of the momentum buffers; beta2 weights the combination of the positive and negative momentum.

(0.9, 1.0)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. use weight_decouple.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/pnm.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class PNM(BaseOptimizer):
    r"""Positive-Negative Momentum Optimizers.

    Two momentum buffers are kept per parameter and swap roles on every step; the update direction is
    ``((1 + beta2) * pos - beta2 * neg) / sqrt((1 + beta2) ** 2 + beta2 ** 2)``.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. beta1 (used squared) is the decay rate of the momentum buffers, beta2 weights the
        combination of the positive and negative momentum.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use weight_decouple.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability. it is validated and stored
        with the group hyper-parameters, but the update rule in `step` does not read it.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas, beta_range_type='[]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        # optimizer-wide flag rather than a per-group hyper-parameter
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            # keep the validated value with the other hyper-parameters so it shows up in `param_groups`
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PNM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the two momentum buffers for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter holds a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['pos_momentum'] = torch.zeros_like(p)
                state['neg_momentum'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # state buffers are only allocated on the group's first step
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            beta1_p2: float = beta1 ** 2  # fmt: skip
            # normalizes the (1 + beta2, -beta2) combination applied below
            noise_norm: float = math.sqrt((1 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # the two buffers swap roles on every step: odd steps update 'pos_momentum', even steps
                # update 'neg_momentum'
                if group['step'] % 2 == 1:
                    pos_momentum, neg_momentum = state['pos_momentum'], state['neg_momentum']
                else:
                    neg_momentum, pos_momentum = state['pos_momentum'], state['neg_momentum']

                p, grad, pos_momentum, neg_momentum = self.view_as_real(p, grad, pos_momentum, neg_momentum)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                pos_momentum.mul_(beta1_p2).add_(grad, alpha=1.0 - beta1_p2)

                # out-of-place `mul` so the stored buffer itself is not scaled
                delta_p = pos_momentum.mul(1.0 + beta2).add_(neg_momentum, alpha=-beta2).mul_(1.0 / noise_norm)

                p.add_(delta_p, alpha=-group['lr'])

        return loss

Prodigy

Bases: BaseOptimizer

An Expeditiously Adaptive Parameter-Free Learner.

Leave LR set to 1 unless you encounter instability.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. betas.

(0.9, 0.999)
beta3 Optional[float]

float. coefficients for computing the Prodigy step-size using running averages. If set to None, uses the value of square root of beta2.

None
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
d_coef float

float. Coefficient in the expression for the estimate of d.

1.0
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. use AdamW style weight decay.

True
fixed_decay bool

bool. fix weight decay.

False
bias_correction bool

bool. turn on Adam's bias correction.

False
safeguard_warmup bool

bool. remove lr from the denominator of D estimate to avoid issues during warm-up stage.

False
eps Optional[float]

float. term added to the denominator to improve numerical stability. when eps is None, use atan2 rather than epsilon and division for parameter updates.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/prodigy.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
class Prodigy(BaseOptimizer):
    r"""An Expeditiously Adaptive Parameter-Free Learner.

        Leave LR set to 1 unless you encounter instability.

    The distance estimate ``d`` is shared by all parameter groups: ``step`` accumulates ``d_numerator`` and
    ``d_de_nom`` over every parameter that has a gradient, derives ``d_hat`` from them and writes the resulting
    ``d`` back into every group.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. betas.
    :param beta3: float. coefficients for computing the Prodigy step-size using running averages. If set to None,
        uses the value of square root of beta2.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param d_coef: float. Coefficient in the expression for the estimate of d.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use AdamW style weight decay.
    :param fixed_decay: bool. fix weight decay.
    :param bias_correction: bool. turn on Adam's bias correction.
    :param safeguard_warmup: bool. remove lr from the denominator of D estimate to avoid issues during warm-up stage.
    :param eps: float. term added to the denominator to improve numerical stability. when eps is None, use atan2 rather
        than epsilon and division for parameter updates.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        beta3: Optional[float] = None,
        d0: float = 1e-6,
        d_coef: float = 1.0,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        bias_correction: bool = False,
        safeguard_warmup: bool = False,
        eps: Optional[float] = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas((*betas, beta3))
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # `d` and `d_max` both start at `d0`; `step` starts at 1 so the first bias correction uses exponent 1.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'beta3': beta3,
            'd': d0,
            'd0': d0,
            'd_max': d0,
            'd_coef': d_coef,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'bias_correction': bias_correction,
            'safeguard_warmup': safeguard_warmup,
            'step': 1,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Prodigy'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create the per-parameter state for every parameter in ``group`` that currently has a gradient.

        Parameters that already own state are left untouched, so calling this repeatedly is safe.

        :param group: GROUP. parameter group whose parameters should be initialised.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['s'] = torch.zeros_like(p)
                # snapshot of the parameter at initialisation time; `step` measures the distance `p0 - p`.
                state['p0'] = p.clone()
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure`` or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The hyper-parameters below are read from the first group and used for all groups.
        group = self.param_groups[0]
        device = group['params'][0].device

        d_de_nom = torch.tensor([0.0], device=device)

        beta1, beta2 = group['betas']
        beta3: float = group['beta3'] if group['beta3'] is not None else math.sqrt(beta2)

        bias_correction1: float = self.debias(beta1, group['step'])
        bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
        bias_correction: float = (bias_correction1 / bias_correction2_sq) if group['bias_correction'] else 1.0

        d, d0 = group['d'], group['d0']
        d_lr: float = d * group['lr'] / bias_correction

        if 'd_numerator' not in group:
            group['d_numerator'] = torch.tensor([0.0], device=device)
        elif group['d_numerator'].device != device:
            group['d_numerator'] = group['d_numerator'].to(device)  # pragma: no cover

        d_numerator = group['d_numerator']
        d_numerator.mul_(beta3)

        # First pass: update the moment estimates and accumulate the statistics for the new `d`.
        for group in self.param_groups:
            # `init_group` skips parameters that already have state, so calling it on every step only creates state
            # for parameters that received their first gradient since the previous call.
            self.init_group(group)

            # The step counter is deliberately NOT advanced here: it is incremented exactly once per successful
            # step in the update loop below, so an early return leaves it unchanged.

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                p0, exp_avg, exp_avg_sq = state['p0'], state['exp_avg'], state['exp_avg_sq']

                d_numerator.add_(torch.dot(grad.flatten(), (p0 - p).flatten()), alpha=(d / d0) * d_lr)

                exp_avg.mul_(beta1).add_(grad, alpha=d * (1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=(d / d0) * (d if group['safeguard_warmup'] else d_lr))

                d_de_nom.add_(s.abs().sum())

        # Nothing was accumulated (no gradients, or all-zero `s`): skip the update entirely.
        if d_de_nom == 0:
            return loss

        # NOTE: `group` is the last entry of `self.param_groups` here (left over from the loop above).
        d_hat = (group['d_coef'] * d_numerator / d_de_nom).item()
        if d == group['d0']:
            d = max(d, d_hat)

        d_max = max(group['d_max'], d_hat)
        d = min(d_max, d * group['growth_rate'])

        # Second pass: publish the new estimate to every group and apply the parameter update.
        for group in self.param_groups:
            group['step'] += 1

            group['d_numerator'] = d_numerator
            group['d_de_nom'] = d_de_nom
            group['d'] = d
            group['d_max'] = d_max
            group['d_hat'] = d_hat

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                self.apply_weight_decay(
                    p,
                    p.grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                de_nom = exp_avg_sq.sqrt()

                if group['eps'] is not None:
                    de_nom.add_(d * group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-d_lr)
                else:
                    # eps is None: use atan2(exp_avg, de_nom) instead of the epsilon-stabilised division.
                    update = exp_avg.clone().atan2_(de_nom)
                    p.add_(update, alpha=-d_lr)

        return loss

Kron

Bases: BaseOptimizer

PSGD with the Kronecker product pre-conditioner.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
pre_conditioner_update_probability Optional[Union[Callable, float]]

Optional[Union[Callable, float]]. Probability of updating the pre-conditioner, given either as a fixed float or as a callable schedule that maps the step count to a probability. If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.

None
max_size_triangular int

int. max size for dim's pre-conditioner to be triangular.

8192
min_ndim_triangular int

int. minimum number of dimensions a layer needs to have triangular pre-conditioners.

2
memory_save_mode Optional[MEMORY_SAVE_MODE_TYPE]

Optional[str]. None, 'one_diag', or 'all_diag', None is default to set all pre-conditioners to be triangular, 'one_diag' sets the largest or last dim to be diagonal per layer, and 'all_diag' sets all pre-conditioners to be diagonal.

None
momentum_into_precondition_update bool

bool. whether to send momentum into pre-conditioner update instead of raw gradients.

True
mu_dtype Optional[dtype]

Optional[torch.dtype]. dtype of the momentum accumulator.

None
precondition_dtype Optional[dtype]

torch.dtype. dtype of the pre-conditioner.

float32
balance_prob float

float. probability of performing balancing.

0.01
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/psgd.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
class Kron(BaseOptimizer):
    """PSGD with the Kronecker product pre-conditioner.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param pre_conditioner_update_probability: Optional[Union[Callable, float]]. Probability of updating the
        pre-conditioner, either a fixed float or a callable schedule evaluated with the step count. If None, defaults
        to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
    :param max_size_triangular: int. max size for dim's pre-conditioner to be triangular.
    :param min_ndim_triangular: int. minimum number of dimensions a layer needs to have triangular pre-conditioners.
    :param memory_save_mode: Optional[str]. None, 'one_diag', or 'all_diag', None is default to set all
        pre-conditioners to be triangular, 'one_diag' sets the largest or last dim to be diagonal per layer, and
        'all_diag' sets all pre-conditioners to be diagonal.
    :param momentum_into_precondition_update: bool. whether to send momentum into pre-conditioner update instead of
        raw gradients.
    :param mu_dtype: Optional[torch.dtype]. dtype of the momentum accumulator.
    :param precondition_dtype: torch.dtype. dtype of the pre-conditioner.
    :param balance_prob: float. probability of performing balancing.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        pre_conditioner_update_probability: Optional[Union[Callable, float]] = None,
        max_size_triangular: int = 8192,
        min_ndim_triangular: int = 2,
        memory_save_mode: Optional[MEMORY_SAVE_MODE_TYPE] = None,
        momentum_into_precondition_update: bool = True,
        mu_dtype: Optional[torch.dtype] = None,
        precondition_dtype: Optional[torch.dtype] = torch.float32,
        balance_prob: float = 0.01,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        if pre_conditioner_update_probability is None:
            pre_conditioner_update_probability = precondition_update_prob_schedule()

        self.balance_prob: float = balance_prob
        # smallest positive normal bfloat16 value, passed to `update_precondition` as its epsilon.
        self.eps: float = torch.finfo(torch.bfloat16).tiny
        # number of `step` calls so far; fed to the update-probability schedule when it is callable.
        self.prob_step: int = 0
        # steps elapsed since the pre-conditioner was last updated.
        self.update_counter: int = 0
        self.maximize = maximize

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'pre_conditioner_update_probability': pre_conditioner_update_probability,
            'max_size_triangular': max_size_triangular,
            'min_ndim_triangular': min_ndim_triangular,
            'memory_save_mode': memory_save_mode,
            'momentum_into_precondition_update': momentum_into_precondition_update,
            'precondition_lr': 1e-1,
            'precondition_init_scale': 1.0,
            'mu_dtype': mu_dtype,
            'precondition_dtype': precondition_dtype,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Kron'

    def init_group(self, group: GROUP, **kwargs) -> None:
        # State is created lazily inside `step`, so there is nothing to do here.
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure`` or None when no closure is given.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The update probability is taken from the first group and shared by all groups.
        update_prob: Union[float, Callable] = self.param_groups[0]['pre_conditioner_update_probability']
        if callable(update_prob):
            update_prob = update_prob(self.prob_step)

        # The pre-conditioner is refreshed once `1 / update_prob` steps have elapsed since the last refresh,
        # i.e. on a counter rather than by drawing a random number.
        self.update_counter += 1
        do_update: bool = self.update_counter >= 1 / update_prob
        if do_update:
            self.update_counter = 0
        self.prob_step += 1

        # Balancing is only considered on steps that also update the pre-conditioner.
        balance: bool = np.random.random() < self.balance_prob and do_update

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            bias_correction1: float = self.debias(group['momentum'], group['step'])

            mu_dtype, precondition_dtype = group['mu_dtype'], group['precondition_dtype']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                # Honour the `maximize` flag the same way the sibling optimizers do; previously it was stored
                # on the instance but never applied.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
                    state['Q'], state['expressions'] = initialize_q_expressions(
                        p,
                        group['precondition_init_scale'],
                        group['max_size_triangular'],
                        group['min_ndim_triangular'],
                        group['memory_save_mode'],
                        dtype=precondition_dtype,
                    )

                momentum_buffer = state['momentum_buffer']
                momentum_buffer.mul_(group['momentum']).add_(grad, alpha=1.0 - group['momentum'])

                if mu_dtype is not None:
                    momentum_buffer = momentum_buffer.to(dtype=mu_dtype, non_blocking=True)

                de_biased_momentum = (momentum_buffer / bias_correction1).to(
                    dtype=precondition_dtype, non_blocking=True
                )

                if grad.dim() > 1 and balance:
                    balance_q(state['Q'])

                if do_update:
                    update_precondition(
                        state['Q'],
                        state['expressions'],
                        torch.randn_like(de_biased_momentum, dtype=precondition_dtype),
                        de_biased_momentum if group['momentum_into_precondition_update'] else grad,
                        group['precondition_lr'],
                        self.eps,
                    )

                precondition_grad = get_precondition_grad(state['Q'], state['expressions'], de_biased_momentum).to(
                    dtype=p.dtype, non_blocking=True
                )

                # Scale the update down (never up) so that its RMS does not exceed roughly 1.1.
                precondition_grad.mul_(torch.clamp(1.1 / (precondition_grad.square().mean().sqrt() + 1e-6), max=1.0))

                # Weight decay is folded into the update, and only for tensors with at least two dimensions.
                # NOTE(review): `weight_decouple` is stored in the defaults but is not consulted here.
                if group['weight_decay'] != 0 and p.dim() >= 2:
                    precondition_grad.add_(p, alpha=group['weight_decay'])

                p.add_(precondition_grad, alpha=-group['lr'])

        return loss

QHAdam

Bases: BaseOptimizer

Quasi-hyperbolic momentum and Adam for deep learning.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
nus Tuple[float, float]

Tuple[float, float]. immediate discount factors used to estimate the gradient and its square.

(1.0, 1.0)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/qhadam.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class QHAdam(BaseOptimizer):
    r"""Quasi-hyperbolic momentum and Adam for deep learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param nus: Tuple[float, float]. immediate discount factors used to estimate the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        nus: Tuple[float, float] = (1.0, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_nus(nus)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'nus': nus,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'QHAdam'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create the per-parameter state for every parameter in ``group`` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters should be initialised.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # running normalisers `1 + beta + beta^2 + ...` for the two moving averages (see `step`).
                state['beta1_weight'] = torch.zeros((1,), dtype=torch.float32, device=grad.device)
                state['beta2_weight'] = torch.zeros((1,), dtype=torch.float32, device=grad.device)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure`` or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']
            nu1, nu2 = group['nus']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # weight <- beta * weight + 1, so `1 / weight` is the share given to the newest sample and
                # `beta_adj = 1 - 1 / weight` is the share kept from the running average.
                beta1_weight, beta2_weight = state['beta1_weight'], state['beta2_weight']
                beta1_weight.mul_(beta1).add_(1.0)
                beta2_weight.mul_(beta2).add_(1.0)

                beta1_adj = 1.0 - (1.0 / beta1_weight)
                beta2_adj = 1.0 - (1.0 / beta2_weight)

                grad_p2 = grad.pow(2)

                exp_avg.mul_(beta1_adj).add_((1.0 - beta1_adj) * grad)
                # Parenthesised so the squared gradient is weighted by (1 - beta2_adj), mirroring the first-moment
                # update above; without the parentheses the expression is `1 - beta2_adj * grad^2`.
                exp_avg_sq.mul_(beta2_adj).add_((1.0 - beta2_adj) * grad_p2)

                # Quasi-hyperbolic mix: nu * running average + (1 - nu) * immediate value.
                avg_grad = exp_avg.mul(nu1)
                if nu1 != 1.0:
                    avg_grad.add_(grad, alpha=1.0 - nu1)

                avg_grad_rms = exp_avg_sq.mul(nu2)
                if nu2 != 1.0:
                    avg_grad_rms.add_(grad_p2, alpha=1.0 - nu2)

                avg_grad_rms.sqrt_().add_(group['eps'])

                p.addcdiv_(avg_grad, avg_grad_rms, value=-group['lr'])

        return loss

QHM

Bases: BaseOptimizer

Quasi-hyperbolic momentum (QHM) optimization algorithm.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor.

0.0
nu float

float. immediate discount factor: the update weights the momentum buffer by nu and the raw gradient by (1 - nu).

1.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/qhm.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
class QHM(BaseOptimizer):
    r"""Quasi-hyperbolic momentum (QHM) optimization algorithm.

    Each step moves a parameter along a blend of its momentum buffer (weighted by ``nu``) and the raw gradient
    (weighted by ``1 - nu``).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param nu: float. immediate discount factor; weight of the momentum buffer relative to the raw gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        nu: float = 1.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        # Validate hyper-parameters before anything is stored.
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_nus(nu)

        self.maximize = maximize

        defaults: DEFAULTS = dict(
            lr=lr,
            momentum=momentum,
            nu=nu,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'QHM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Allocate a zeroed momentum buffer for each parameter of ``group`` that has a gradient and no state yet.

        :param group: GROUP. parameter group whose parameters should be initialised.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            if param.grad is None:
                continue

            if param.grad.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if not param_state:
                param_state['momentum_buffer'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by ``closure`` or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # First visit of a group creates its state; later visits only advance the counter.
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            for param in group['params']:
                if param.grad is None:
                    continue

                gradient = param.grad

                self.maximize_gradient(gradient, maximize=self.maximize)

                momentum_buffer = self.state[param]['momentum_buffer']

                param, gradient, momentum_buffer = self.view_as_real(param, gradient, momentum_buffer)

                self.apply_weight_decay(
                    p=param,
                    grad=gradient,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                # Exponential moving average of the gradient.
                momentum_buffer.mul_(group['momentum'])
                momentum_buffer.add_(gradient, alpha=1.0 - group['momentum'])

                # Blend of the averaged and the immediate gradient, applied as two in-place updates.
                param.add_(momentum_buffer, alpha=-group['lr'] * group['nu'])
                param.add_(gradient, alpha=-group['lr'] * (1.0 - group['nu']))

        return loss

RACS

Bases: BaseOptimizer

Row and Column Scaled SGD.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
beta float

float. momentum factor.

0.9
alpha float

float. scaler.

0.05
gamma float

float. limiter threshold.

1.01
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/racs.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class RACS(BaseOptimizer):
    r"""Row and Column Scaled SGD.

    Every gradient is treated as a 2D matrix (lower/higher-rank gradients are reshaped), scaled by running
    estimates of its per-row and per-column mean squared values, and the norm of the scaled gradient is limited
    relative to the norm used on the previous step.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. momentum factor.
    :param alpha: float. scaler.
    :param gamma: float. limiter threshold.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        beta: float = 0.9,
        alpha: float = 0.05,
        gamma: float = 1.01,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0)
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_positive(gamma, 'gamma')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'beta': beta,
            'alpha': alpha,
            'gamma': gamma,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'RACS'

    def init_group(self, group: GROUP, **kwargs) -> None:
        # Per-parameter state is created lazily inside `step`, so there is nothing to do per group.
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta = group['beta']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                # honour the `maximize` flag, which was previously stored but never applied.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # work on a 2D (rows, columns) layout. `numel()` / `size(0)` are used instead of `len()`
                # so that a 0-dim gradient is reshaped to (1, 1) rather than raising.
                if grad.ndim < 2:
                    grad = grad.reshape(grad.numel(), 1)
                elif grad.ndim > 2:
                    grad = grad.reshape(grad.size(0), -1)

                if len(state) == 0:
                    state['s'] = torch.zeros(grad.size(0), dtype=grad.dtype, device=grad.device)
                    state['q'] = torch.ones(grad.size(1), dtype=grad.dtype, device=grad.device)
                    state['theta'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s, q = state['s'], state['q']

                # exponential moving averages of the mean squared gradient per row (s) and per column (q).
                grad_p2 = grad.pow(2)
                s.mul_(beta).add_(grad_p2.mean(dim=1), alpha=1.0 - beta)
                q.mul_(beta).add_(grad_p2.mean(dim=0), alpha=1.0 - beta)

                s_sq = s.add(group['eps']).sqrt_().unsqueeze(1)
                q_sq = q.add(group['eps']).sqrt_().unsqueeze(0)

                grad_hat = grad / (s_sq * q_sq)

                # limit the scaled-gradient norm to at most `gamma` times the norm kept from the previous step
                # (`theta`); the very first step of the group is left unscaled.
                grad_hat_norm = torch.norm(grad_hat)
                threshold = (
                    group['gamma'] / max(grad_hat_norm / (state['theta'] + group['eps']), group['gamma'])
                    if group['step'] > 1
                    else 1.0
                )
                state['theta'] = grad_hat_norm.mul_(threshold)

                p.add_(grad_hat.view_as(p), alpha=-group['lr'] * group['alpha'] * threshold)

        return loss

RAdam

Bases: BaseOptimizer

Rectified Adam.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
n_sma_threshold int

int. (recommended is 5).

5
degenerated_to_sgd bool

bool. perform SGD update when variance of gradient is high.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/radam.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class RAdam(BaseOptimizer):
    r"""Rectified Adam.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param n_sma_threshold: int. (recommended is 5).
    :param degenerated_to_sgd: bool. degenerated to SGD.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.maximize = maximize

        # extra keyword arguments are stored last, as additional per-group options.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        defaults.update(kwargs)

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'RAdam'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the moment buffers for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        use_adanorm = group.get('adanorm')

        for p in group['params']:
            grad = p.grad
            if grad is None:
                continue

            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]
            if len(state) != 0:
                continue

            state['exp_avg'] = torch.zeros_like(p)
            state['exp_avg_sq'] = torch.zeros_like(p)

            if use_adanorm:
                state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])

            rectified_step_size, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=rectified_step_size,
                bias_correction1=bias_correction1,
            )

            # both decisions depend only on group-level quantities, so evaluate them once per group.
            has_positive_step = step_size > 0
            is_rectified = n_sma >= self.n_sma_threshold

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                param, grad, exp_avg, exp_avg_sq = self.view_as_real(
                    p, grad, state['exp_avg'], state['exp_avg_sq']
                )

                # weight decay is skipped on steps where no parameter update is going to be applied.
                if has_positive_step or is_rectified:
                    self.apply_weight_decay(
                        p=param,
                        grad=None,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if is_rectified:
                    # adaptive update: first moment divided by the root of the second moment.
                    denominator = exp_avg_sq.sqrt().add_(group['eps'])
                    param.addcdiv_(exp_avg, denominator, value=-step_size)
                elif has_positive_step:
                    # momentum-only update without the adaptive denominator.
                    param.add_(exp_avg, alpha=-step_size)

        return loss

Ranger

Bases: BaseOptimizer

a synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.95, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
n_sma_threshold int

int. (recommended is 5).

5
degenerated_to_sgd bool

bool. perform SGD update when variance of gradient is high.

False
use_gc bool

bool. use Gradient Centralization (both convolution & fc layers).

True
gc_conv_only bool

bool. use Gradient Centralization (only convolution layer).

False
eps float

float. term added to the denominator to improve numerical stability.

1e-05
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ranger.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
class Ranger(BaseOptimizer):
    r"""A synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param alpha: float. Lookahead interpolation weight used when the slow buffer is moved towards the
        current (fast) weights.
    :param k: int. number of steps between two Lookahead merges.
    :param n_sma_threshold: int. (recommended is 5).
    :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param use_gc: bool. use Gradient Centralization (both convolution & fc layers).
    :param gc_conv_only: bool. use Gradient Centralization (only convolution layer).
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.95, 0.999),
        alpha: float = 0.5,
        k: int = 6,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = False,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        use_gc: bool = True,
        gc_conv_only: bool = False,
        eps: float = 1e-5,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(alpha, 'alpha', 0.0, 1.0, range_type='[]')
        self.validate_positive(k, 'k')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.use_gc = use_gc
        # gradients are centralized only when their rank exceeds this value:
        # 3 restricts it to tensors of rank 4 and above, 1 to tensors of rank 2 and above.
        self.gc_gradient_threshold: int = 3 if gc_conv_only else 1
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'k': k,
            'step_counter': 0,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Ranger'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the moment buffers and the Lookahead slow buffer for parameters that have a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['slow_buffer'] = p.clone()

                if group.get('adanorm'):
                    state['exp_grad_adanorm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq, slow_buffer = state['exp_avg'], state['exp_avg_sq'], state['slow_buffer']

                p, grad, exp_avg, exp_avg_sq, slow_buffer = self.view_as_real(
                    p, grad, exp_avg, exp_avg_sq, slow_buffer
                )

                if self.use_gc and grad.dim() > self.gc_gradient_threshold:
                    centralize_gradient(grad, gc_conv_only=False)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group.get('adanorm', False),
                    exp_grad_norm=state.get('exp_grad_adanorm', None),
                    r=group.get('adanorm_r', None),
                )

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if n_sma >= self.n_sma_threshold:
                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    # only take the momentum-only fallback when the rectified step size is positive.
                    # With a non-positive step size, `alpha=-step_size` would move the weights along the
                    # momentum direction (i.e. away from the descent direction) instead of skipping the update.
                    p.add_(exp_avg, alpha=-step_size)

                # Lookahead: every `k` steps pull the slow weights towards the fast ones and restart from them.
                if group['step'] % group['k'] == 0:
                    slow_buffer.lerp_(p, weight=group['alpha'])
                    p.copy_(slow_buffer)

        return loss

Ranger21

Bases: BaseOptimizer

Integrating the latest deep learning components into a single optimizer.

Here are the components
    * uses the AdamW optimizer as its core (or, optionally, MadGrad)
    * Adaptive gradient clipping
    * Gradient centralization
    * Positive-Negative momentum
    * Norm loss
    * Stable weight decay
    * Linear learning rate warm-up
    * Explore-exploit learning rate schedule
    * Lookahead
    * Softplus transformation
    * Gradient Normalization
    * Corrects the denominator (AdamD).

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
num_iterations int

int. number of the total training steps. Ranger21 optimizer schedules the learning rate with its own recipes.

required
lr float

float. learning rate.

0.001
beta0 float

float. Manages the amplitude of the noise introduced by positive-negative momentum. While 0.9 is a recommended default value, you can use -0.5 to minimize the noise.

0.9
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
use_softplus bool

bool. use softplus to smooth.

True
beta_softplus float

float. beta.

50.0
disable_lr_scheduler bool

bool. whether to disable learning rate schedule.

False
num_warm_up_iterations Optional[int]

Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate warmup.

None
num_warm_down_iterations Optional[int]

Optional[int]. number of warm-down iterations. Ranger21 performs Explore-exploit learning rate scheduling.

None
agc_clipping_value float

float. clipping value for adaptive gradient clipping (AGC).

0.01
agc_eps float

float. eps for AGC

0.001
centralize_gradients bool

bool. use GC both convolution & fc layers.

True
normalize_gradients bool

bool. use gradient normalization.

True
lookahead_merge_time int

int. merge time.

5
lookahead_blending_alpha float

float. blending alpha.

0.5
weight_decay float

float. weight decay (L2 penalty).

0.0001
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
norm_loss_factor float

float. norm loss factor.

0.0001
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/ranger21.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
class Ranger21(BaseOptimizer):
    r"""Integrating the latest deep learning components into a single optimizer.

        Here are the components
            * uses the AdamW optimizer as its core (or, optionally, MadGrad)
            * Adaptive gradient clipping
            * Gradient centralization
            * Positive-Negative momentum
            * Norm loss
            * Stable weight decay
            * Linear learning rate warm-up
            * Explore-exploit learning rate schedule
            * Lookahead
            * Softplus transformation
            * Gradient Normalization
            * Corrects the denominator (AdamD).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param num_iterations: int. number of the total training steps. Ranger21 optimizer schedules the learning rate
        with its own recipes.
    :param lr: float. learning rate.
    :param beta0: float. Manages the amplitude of the noise introduced by positive-negative momentum.
        While 0.9 is a recommended default value, you can use -0.5 to minimize the noise.
        NOTE(review): the constructor validates this value against the range [0.0, 1.0] and does not read it
        anywhere else in this class - confirm the intended behaviour.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param use_softplus: bool. use softplus to smooth.
    :param beta_softplus: float. beta.
    :param disable_lr_scheduler: bool. whether to disable learning rate schedule.
    :param num_warm_up_iterations: Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate
        warmup.
    :param num_warm_down_iterations: Optional[int]. number of warm-down iterations. Ranger21 performs Explore-exploit
        learning rate scheduling.
    :param warm_down_min_lr: float. lower bound of the learning rate during the warm-down phase.
    :param agc_clipping_value: float. clipping value for adaptive gradient clipping (AGC).
    :param agc_eps: float. eps for AGC
    :param centralize_gradients: bool. use GC both convolution & fc layers.
    :param normalize_gradients: bool. use gradient normalization.
    :param lookahead_merge_time: int. merge time.
    :param lookahead_blending_alpha: float. blending alpha.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param norm_loss_factor: float. norm loss factor.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(  # pylint: disable=R0913
        self,
        params: PARAMETERS,
        num_iterations: int,
        lr: float = 1e-3,
        beta0: float = 0.9,
        betas: BETAS = (0.9, 0.999),
        use_softplus: bool = True,
        beta_softplus: float = 50.0,
        disable_lr_scheduler: bool = False,
        num_warm_up_iterations: Optional[int] = None,
        num_warm_down_iterations: Optional[int] = None,
        warm_down_min_lr: float = 3e-5,
        agc_clipping_value: float = 1e-2,
        agc_eps: float = 1e-3,
        centralize_gradients: bool = True,
        normalize_gradients: bool = True,
        lookahead_merge_time: int = 5,
        lookahead_blending_alpha: float = 0.5,
        weight_decay: float = 1e-4,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        norm_loss_factor: float = 1e-4,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(warm_down_min_lr)
        self.validate_betas(betas)
        self.validate_range(beta0, 'beta0', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(agc_clipping_value, 'agc_clipping_value')
        self.validate_non_negative(eps, 'eps')
        self.validate_non_negative(agc_eps, 'agc_eps')

        self.min_lr = warm_down_min_lr
        self.use_softplus = use_softplus
        self.beta_softplus = beta_softplus
        self.disable_lr_scheduler = disable_lr_scheduler
        self.agc_clipping_value = agc_clipping_value
        self.agc_eps = agc_eps
        self.centralize_gradients = centralize_gradients
        self.normalize_gradients = normalize_gradients
        self.lookahead_merge_time = lookahead_merge_time
        self.lookahead_blending_alpha = lookahead_blending_alpha
        self.norm_loss_factor = norm_loss_factor
        self.maximize = maximize

        self.lookahead_step: int = 0
        self.starting_lr: float = lr
        self.current_lr: float = lr

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

        # schedule lengths fall back to the built-in heuristics when they are not given explicitly.
        self.num_warm_up_iterations: int = (
            self.build_warm_up_iterations(num_iterations, betas[1])
            if num_warm_up_iterations is None
            else num_warm_up_iterations
        )
        self.num_warm_down_iterations: int = (
            self.build_warm_down_iterations(num_iterations)
            if num_warm_down_iterations is None
            else num_warm_down_iterations
        )
        self.start_warm_down: int = num_iterations - self.num_warm_down_iterations
        self.warm_down_lr_delta: float = self.starting_lr - self.min_lr

    def __str__(self) -> str:
        return 'Ranger21'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the optimizer state for every parameter of ``group`` that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['grad_ma'] = torch.zeros_like(p)
                state['variance_ma'] = torch.zeros_like(p)
                state['lookahead_params'] = p.clone()
                state['neg_grad_ma'] = torch.zeros_like(p)
                state['max_variance_ma'] = torch.zeros_like(p)

    @staticmethod
    def build_warm_up_iterations(total_iterations: int, beta2: float, warm_up_pct: float = 0.22) -> int:
        r"""Return the number of linear warm-up iterations.

        The default is ``ceil(2 / (1 - beta2))``; when that exceeds 45% of ``total_iterations``,
        ``warm_up_pct * total_iterations`` is used instead.

        :param total_iterations: int. total number of training steps.
        :param beta2: float. second-moment decay coefficient.
        :param warm_up_pct: float. fraction of the training steps used for the fallback warm-up length.
        """
        warm_up_iterations: int = math.ceil(2.0 / (1.0 - beta2))  # default un-tuned linear warmup
        beta_pct: float = warm_up_iterations / total_iterations
        return int(warm_up_pct * total_iterations) if beta_pct > 0.45 else warm_up_iterations

    @staticmethod
    def build_warm_down_iterations(total_iterations: int, warm_down_pct: float = 0.72) -> int:
        r"""Return the number of warm-down iterations, i.e. the steps left after ``warm_down_pct`` of training.

        :param total_iterations: int. total number of training steps.
        :param warm_down_pct: float. fraction of the training steps after which the warm-down starts.
        """
        start_warm_down: int = int(warm_down_pct * total_iterations)
        return total_iterations - start_warm_down

    def warm_up_dampening(self, lr: float, step: int) -> float:
        r"""Scale ``lr`` linearly with ``step`` during warm-up; return ``lr`` unchanged afterwards.

        Updates ``self.current_lr`` while the warm-up is active.

        :param lr: float. base learning rate.
        :param step: int. current step.
        """
        if step > self.num_warm_up_iterations:
            return lr

        warm_up_current_pct: float = min(1.0, (step / self.num_warm_up_iterations))

        self.current_lr = lr * warm_up_current_pct

        return self.current_lr

    def warm_down(self, lr: float, iteration: int) -> float:
        r"""Linearly decay the learning rate from the starting value towards ``min_lr`` during warm-down.

        Returns ``lr`` unchanged before the warm-down starts; otherwise updates and returns ``self.current_lr``.

        :param lr: float. current learning rate.
        :param iteration: int. current step.
        """
        if iteration < self.start_warm_down:
            return lr

        # start iteration from 1, not 0
        warm_down_iteration: int = max((iteration + 1) - self.start_warm_down, 1)
        warm_down_pct: float = min(warm_down_iteration / (self.num_warm_down_iterations + 1), 1.0)

        self.current_lr = max(self.starting_lr - self.warm_down_lr_delta * warm_down_pct, self.min_lr)

        return self.current_lr

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the value returned by ``closure``, or None when no closure is given.
        :raises ZeroParameterSizeError: if no parameter has a gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        variance_ma_sum: float = 1.0

        # first pass: pre-process the gradients in place and accumulate the global (de-biased) variance,
        # which the second pass uses to scale the weight decay.
        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                param_size += p.numel()

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad.copy_(agc(p, grad, self.agc_eps, self.agc_clipping_value))

                # honour the constructor flags; both transforms were previously applied unconditionally.
                if self.centralize_gradients:
                    centralize_gradient(grad, gc_conv_only=False)
                if self.normalize_gradients:
                    normalize_gradient(grad)

                variance_ma = state['variance_ma']
                variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                variance_ma_sum += (variance_ma / bias_correction2).sum()

        if param_size == 0:
            raise ZeroParameterSizeError()

        variance_normalized = math.sqrt(variance_ma_sum / param_size)

        # second pass: weight decay, norm loss, positive-negative momentum and the parameter update.
        for group in self.param_groups:
            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

            if self.disable_lr_scheduler:
                lr: float = group['lr']
            else:
                lr: float = self.warm_up_dampening(group['lr'], group['step'])
                lr = self.warm_down(lr, group['step'])

            step_size: float = self.apply_adam_debias(group.get('adam_debias', False), lr, bias_correction1)

            for p in group['params']:
                if p.grad is None:
                    continue

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=1.0 / variance_normalized,
                )

                correction = 2.0 * self.norm_loss_factor * (1.0 - 1.0 / unit_norm(p).add_(group['eps']))
                p.mul_(1.0 - lr * correction)

                # the two momentum buffers swap roles on every step (positive-negative momentum).
                state = self.state[p]
                if group['step'] % 2 == 1:
                    grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
                else:
                    grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']

                # NOTE(review): `max_variance_ma` is only ever initialized to zeros in this class and the
                # maximum is written into `variance_ma`, so this looks like a no-op for a non-negative
                # `variance_ma` - confirm whether an AMSGrad-style running maximum was intended.
                variance_ma = state['variance_ma']
                torch.max(state['max_variance_ma'], variance_ma, out=variance_ma)

                de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps'])

                if self.use_softplus:
                    de_nom = softplus(de_nom, beta=self.beta_softplus)

                grad = p.grad
                if self.centralize_gradients:
                    centralize_gradient(grad, gc_conv_only=False)
                if self.normalize_gradients:
                    normalize_gradient(grad)

                grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1.0 - beta1 ** 2)  # fmt: skip

                pn_momentum = grad_ma.mul(2.0).add_(neg_grad_ma, alpha=-1.0).mul_(1.0 / noise_norm)
                p.addcdiv_(pn_momentum, de_nom, value=-step_size)

        self.lookahead_process_step()

        return loss

    def lookahead_process_step(self) -> None:
        r"""Blend the parameters with their Lookahead copies every ``lookahead_merge_time`` calls."""
        self.lookahead_step += 1
        if self.lookahead_step >= self.lookahead_merge_time:
            self.lookahead_step: int = 0
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    state = self.state[p]

                    # p <- alpha * p + (1 - alpha) * lookahead_params, then remember the blended weights.
                    p.mul_(self.lookahead_blending_alpha).add_(
                        state['lookahead_params'],
                        alpha=1.0 - self.lookahead_blending_alpha,
                    )
                    state['lookahead_params'].copy_(p)

RotoGrad

Bases: RotateOnly

Implementation of RotoGrad as described in the original paper.

Parameters:

Name Type Description Default
backbone Module

nn.Module. shared module.

required
heads Sequence[Module]

List[nn.Module]. task-specific modules.

required
latent_size int

int. size of the shared representation, size of the output of the backbone.z.

required
burn_in_period int

int. When back-propagating towards the shared parameters, each task loss is normalized by dividing it by its initial value, $L_k(t) / L_k(t_0 = 0)$. This parameter sets the number of iterations after which the denominator is replaced by the value of the loss at that iteration, that is, $t_0 = \text{burn\_in\_period}$. This is done to overcome problems with losses quickly changing in the first iterations.

20
normalize_losses bool

bool. Whether to use this normalized losses to back-propagate through the task-specific parameters as well.

False
Source code in pytorch_optimizer/optimizer/rotograd.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
class RotoGrad(RotateOnly):
    r"""Implementation of RotoGrad as described in the original paper.

    :param backbone: nn.Module. shared module.
    :param heads: List[nn.Module]. task-specific modules.
    :param latent_size: int. size of the shared representation, i.e. the size of the output of the backbone.
    :param burn_in_period: int. When back-propagating towards the shared parameters, *each task loss is normalized
        dividing by its initial value*, :math:`{L_k(t)}/{L_k(t_0 = 0)}`. This parameter sets a number of iterations
        after which the denominator will be replaced by the value of the loss at that iteration, that is,
        :math:`t_0 = burn\_in\_period`. This is done to overcome problems with losses quickly changing
        in the first iterations.
    :param normalize_losses: bool. Whether to use these normalized losses to back-propagate through the task-specific
        parameters as well.
    """

    num_tasks: int
    backbone: nn.Module
    heads: Sequence[nn.Module]
    rep: torch.Tensor

    def __init__(
        self,
        backbone: nn.Module,
        heads: Sequence[nn.Module],
        latent_size: int,
        *args,
        burn_in_period: int = 20,
        normalize_losses: bool = False,
    ):
        # NOTE(review): `burn_in_period` is forwarded positionally, ahead of `*args`; the parent signature is not
        # visible here, so confirm it expects the value in that slot.
        super().__init__(backbone, heads, latent_size, burn_in_period, *args, normalize_losses=normalize_losses)

        # per-task gradient norms captured at the reference iteration (`None` until the first `_rep_grad` call).
        self.initial_grads = None
        # number of `_rep_grad` calls performed so far.
        self.counter: int = 0

    def _rep_grad(self):
        r"""Combine the per-task gradients into one gradient with a shared, convergence-weighted magnitude.

        Each task gradient in ``self.original_grads`` is rescaled to unit norm and multiplied by a common magnitude:
        the sum of the task gradient norms weighted by each task's normalized convergence ratio (current norm divided
        by the reference norm). The reference norms are (re)captured on the first call and again when the call
        counter equals ``self.burn_in_period``; on those calls every ratio is one.

        ``self.original_grads`` and ``self.burn_in_period`` are not assigned in this class; presumably the parent
        class provides them - verify against `RotateOnly`.

        :return: torch.Tensor. sum of the rescaled task gradients.
        """
        super()._rep_grad()

        # clamp so that neither the ratios nor the final normalization can divide by zero.
        grad_norms = [torch.linalg.norm(g, keepdim=True).clamp_min(1e-15) for g in self.original_grads]
        if self.initial_grads is None or self.counter == self.burn_in_period:
            self.initial_grads = grad_norms
            # create the unit ratios on the same device as the norms; a CPU-only tensor here cannot be multiplied
            # with gradient norms that live on an accelerator.
            conv_ratios = [torch.ones((1,), device=norm.device) for norm in grad_norms]
        else:
            conv_ratios = [x / y for x, y in zip(grad_norms, self.initial_grads)]

        self.counter += 1

        alphas = [x / torch.clamp(sum(conv_ratios), 1e-15) for x in conv_ratios]
        weighted_sum_norms = sum(a * g for a, g in zip(alphas, grad_norms))

        return sum(g / n * weighted_sum_norms for g, n in zip(self.original_grads, grad_norms))

SAM

Bases: BaseOptimizer

Sharpness-Aware Minimization for Efficiently Improving Generalization.

Example:

Here's an example::

    model = YourModel()
    base_optimizer = Ranger21
    optimizer = SAM(model.parameters(), base_optimizer)

    for input, output in data:
        # first forward-backward pass

        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass
        # make sure to do a full forward pass
        loss_function(output, model(input)).backward()
        optimizer.second_step(zero_grad=True)

Alternative example with a single closure-based step function::

    model = YourModel()
    base_optimizer = Ranger21
    optimizer = SAM(model.parameters(), base_optimizer)

    def closure():
        loss = loss_function(output, model(input))
        loss.backward()
        return loss

    for input, output in data:
        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.step(closure)
        optimizer.zero_grad()

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer OPTIMIZER

OPTIMIZER. base optimizer.

required
rho float

float. size of the neighborhood for computing the max loss.

0.05
adaptive bool

bool. element-wise Adaptive SAM.

False
use_gc bool

bool. perform gradient centralization, GCSAM variant.

False
perturb_eps float

float. eps for perturbation.

1e-12
kwargs

Dict. parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class SAM(BaseOptimizer):
    r"""Sharpness-Aware Minimization for Efficiently Improving Generalization.

    The optimizer wraps a base optimizer. ``first_step`` moves the weights along the (scaled) gradient to a
    perturbed point, ``second_step`` restores the original weights and lets the base optimizer apply the update
    using the gradients computed at the perturbed point.

    Two-pass usage::

        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)

        for input, output in data:
            # first forward-backward pass
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass (must be a full forward pass)
            loss_function(output, model(input)).backward()
            optimizer.second_step(zero_grad=True)

    Closure-based usage with a single ``step`` call::

        model = YourModel()
        base_optimizer = Ranger21
        optimizer = SAM(model.parameters(), base_optimizer)

        def closure():
            loss = loss_function(output, model(input))
            loss.backward()
            return loss

        for input, output in data:
            loss = loss_function(output, model(input))
            loss.backward()
            optimizer.step(closure)
            optimizer.zero_grad()

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: OPTIMIZER. base optimizer.
    :param rho: float. size of the neighborhood for computing the max loss.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param use_gc: bool. perform gradient centralization, GCSAM variant.
    :param perturb_eps: float. eps for perturbation.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        params: PARAMETERS,
        base_optimizer: OPTIMIZER,
        rho: float = 0.05,
        adaptive: bool = False,
        use_gc: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')
        self.validate_non_negative(perturb_eps, 'perturb_eps')

        self.use_gc = use_gc
        self.perturb_eps = perturb_eps

        # the extra keyword arguments become group defaults and are also forwarded to the base optimizer.
        defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
        defaults.update(kwargs)

        super().__init__(params, defaults)

        self.base_optimizer: Optimizer = base_optimizer(self.param_groups, **kwargs)
        # share one list of groups so that both optimizers see the same hyper-parameters.
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'SAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Do nothing; this wrapper creates its per-parameter state inside ``first_step``."""

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        r"""Back up the weights and move them to the perturbed point.

        :param zero_grad: bool. call ``zero_grad()`` after perturbing the weights.
        """
        first_param = self.param_groups[0]['params'][0]

        grad_norm = get_global_gradient_norm(self.param_groups, first_param.device)
        grad_norm.add_(self.perturb_eps)

        for group in self.param_groups:
            scale = group['rho'] / grad_norm

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue

                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                # keep a copy of the unperturbed weights for `second_step`.
                self.state[p]['old_p'] = p.clone()

                factor = torch.pow(p, 2) if group['adaptive'] else 1.0
                perturbation = factor * grad * scale.to(p)

                p.add_(perturbation)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        r"""Restore the backed-up weights and run the base optimizer.

        :param zero_grad: bool. call ``zero_grad()`` after the base optimizer step.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.data = self.state[p]['old_p']

        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        r"""Run both SAM steps, re-evaluating the model through ``closure`` in between.

        :param closure: CLOSURE. callable doing a full forward-backward pass. required.
        :raises NoClosureError: when ``closure`` is ``None``.
        """
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step(zero_grad=True)

        # the method itself runs without autograd, so it has to be switched back on for the closure.
        with torch.enable_grad():
            closure()

        self.second_step()

    def load_state_dict(self, state_dict: Dict):
        r"""Load the optimizer state and re-link the base optimizer to the loaded parameter groups.

        :param state_dict: Dict. optimizer state.
        """
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

ScheduleFreeSGD

Bases: BaseOptimizer

Schedule-Free SGD.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
momentum float

float. momentum factor, must be between 0 and 1 exclusive.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
r float

float. use polynomial weighting in the average with power r.

0.0
weight_lr_power float

float. during warmup, the weights in the average will be equal to lr raised to this power. set to 0 for no weighting.

2.0
warmup_steps int

int. enables a linear learning rate warmup.

0
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class ScheduleFreeSGD(BaseOptimizer):
    r"""Schedule-Free SGD.

    While training, the parameters hold the interpolated iterate; ``state['z']`` holds the base iterate. Call
    ``eval()`` before evaluation and ``train()`` before resuming training so that the parameters are converted
    between the interpolated and the averaged iterate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor, must be between 0 and 1 exclusive.
    :param weight_decay: float. weight decay (L2 penalty).
    :param r: float. use polynomial weighting in the average with power r.
    :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
        set to 0 for no weighting.
    :param warmup_steps: int. enables a linear learning rate warmup.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        warmup_steps: int = 0,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        # NOTE(review): the '[]' range presumably accepts momentum == 0.0, yet `eval()` divides by the momentum;
        # confirm whether the bound should be exclusive as the docstring states.
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'r': r,
            'weight_lr_power': weight_lr_power,
            'warmup_steps': warmup_steps,
            'eps': eps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
        }

        super().__init__(params, defaults)

        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'ScheduleFreeSGD'

    def eval(self):
        r"""Switch every group that is in training mode to evaluation mode.

        Each parameter with a ``z`` buffer is interpolated in place towards ``z`` with weight ``1 - 1 / momentum``,
        which is the inverse of the interpolation applied by ``train()``.
        """
        for group in self.param_groups:
            momentum = group['momentum']
            if group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / momentum)
                group['train_mode'] = False

    def train(self):
        r"""Switch every group that is in evaluation mode back to training mode.

        Each parameter with a ``z`` buffer is interpolated in place towards ``z`` with weight ``1 - momentum``.
        """
        for group in self.param_groups:
            momentum = group['momentum']
            if not group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - momentum)
                group['train_mode'] = True

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the ``z`` buffer for every parameter of the group that has a gradient and no state yet.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: when a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = p.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # `init_group` leaves parameters that already own state untouched, so it is run on every step: a
            # parameter that had no gradient on the first step still gets its `z` buffer once a gradient shows up,
            # instead of failing on the `state['z']` lookup below.
            self.init_group(group)
            group['step'] = group.get('step', 0) + 1

            warmup_steps: int = group['warmup_steps']
            # linear warmup factor; 1.0 once `warmup_steps` is reached (and always when `warmup_steps` is 0).
            schedule: float = group['step'] / warmup_steps if group['step'] < warmup_steps else 1.0

            momentum = group['momentum']

            lr: float = group['lr'] * schedule
            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            # weight of the current iterate in the running average and its share of the accumulated weights.
            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z = state['z']

                p, grad, z = self.view_as_real(p, grad, z)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                # move the parameter towards `z`, then apply the gradient step of the interpolated iterate.
                p.lerp_(z, weight=checkpoint)
                p.add_(grad, alpha=lr * (momentum * (1.0 - checkpoint) - 1))

                # plain gradient step on the base iterate.
                z.sub_(grad, alpha=lr)

        return loss

ScheduleFreeAdamW

Bases: BaseOptimizer

Schedule-Free AdamW.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0025
betas BETAS

BETAS. beta1 is the interpolation (momentum) coefficient between the base and the averaged iterates; beta2 is the decay rate of the running average of the squared gradient.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
r float

float. use polynomial weighting in the average with power r.

0.0
weight_lr_power float

float. during warmup, the weights in the average will be equal to lr raised to this power. set to 0 for no weighting.

2.0
warmup_steps int

int. enables a linear learning rate warmup.

0
ams_bound bool

bool. whether to use the AMSBound variant.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
class ScheduleFreeAdamW(BaseOptimizer):
    r"""Schedule-Free AdamW.

    While training, the parameters hold the interpolated iterate; ``state['z']`` holds the base iterate. Call
    ``eval()`` before evaluation and ``train()`` before resuming training so that the parameters are converted
    between the interpolated and the averaged iterate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. beta1 is the interpolation (momentum) coefficient, beta2 is the decay rate of the running
        average of the squared gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param r: float. use polynomial weighting in the average with power r.
    :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
        set to 0 for no weighting.
    :param warmup_steps: int. enables a linear learning rate warmup.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 2.5e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        warmup_steps: int = 0,
        ams_bound: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'r': r,
            'weight_lr_power': weight_lr_power,
            'warmup_steps': warmup_steps,
            'ams_bound': ams_bound,
            'eps': eps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'use_palm': kwargs.get('use_palm', False),
        }

        super().__init__(params, defaults)

        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'ScheduleFreeAdamW'

    def eval(self):
        r"""Switch every group that is in training mode to evaluation mode.

        Each parameter with a ``z`` buffer is interpolated in place towards ``z`` with weight ``1 - 1 / beta1``,
        which is the inverse of the interpolation applied by ``train()``.
        """
        for group in self.param_groups:
            beta1, _ = group['betas']
            if group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1)
                group['train_mode'] = False

    def train(self):
        r"""Switch every group that is in evaluation mode back to training mode.

        Each parameter with a ``z`` buffer is interpolated in place towards ``z`` with weight ``1 - beta1``.
        """
        for group in self.param_groups:
            beta1, _ = group['betas']
            if not group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - beta1)
                group['train_mode'] = True

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the missing state buffers for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: when a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = p.clone()
                state['exp_avg_sq'] = torch.zeros_like(p)

            # `step` hands `state.get('max_exp_avg_sq')` to `apply_ams_bound`; without this buffer that lookup is
            # always `None`, so no running maximum could be kept when `ams_bound` is enabled.
            # NOTE(review): for complex parameters this buffer is not passed through `view_as_real` in `step`;
            # confirm that `apply_ams_bound` copes with that.
            if group.get('ams_bound', False) and 'max_exp_avg_sq' not in state:
                state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # `init_group` only creates buffers that are missing, so it is run on every step: a parameter that had
            # no gradient on the first step still gets its state once a gradient shows up, instead of failing on
            # the state lookups below.
            self.init_group(group)
            group['step'] = group.get('step', 0) + 1

            warmup_steps: int = group['warmup_steps']
            # linear warmup factor; 1.0 once `warmup_steps` is reached (and always when `warmup_steps` is 0).
            schedule: float = group['step'] / warmup_steps if group['step'] < warmup_steps else 1.0

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            lr: float = group['lr'] * schedule
            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            # weight of the current iterate in the running average and its share of the accumulated weights.
            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z, exp_avg_sq = state['z'], state['exp_avg_sq']

                p, grad, z, exp_avg_sq = self.view_as_real(p, grad, z, exp_avg_sq)

                # exponential moving average of the squared gradient.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq.div(bias_correction2),
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                # the gradient buffer is normalized in place and reused for both updates below.
                grad.div_(de_nom)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                # move the parameter towards `z`, then apply the gradient step of the interpolated iterate.
                p.lerp_(z, weight=checkpoint)
                p.add_(grad, alpha=lr * (beta1 * (1.0 - checkpoint) - 1))

                # plain (normalized) gradient step on the base iterate.
                z.sub_(grad, alpha=lr)

        return loss

ScheduleFreeRAdam

Bases: BaseOptimizer

Schedule-Free RAdam.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0025
betas BETAS

BETAS. beta1 is the interpolation (momentum) coefficient between the base and the averaged iterates; beta2 is the decay rate of the running average of the squared gradient.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
r float

float. use polynomial weighting in the average with power r.

0.0
weight_lr_power float

float. during warmup, the weights in the average will be equal to lr raised to this power. set to 0 for no weighting.

2.0
silent_sgd_phase bool

bool. the optimizer will not use the first SGD phase of RAdam. This means that the optimizer will not update model parameters during the early training steps (e.g., < 5 when β_2 = 0.999), but just update the momentum values of the optimizer. This helps stabilize training by ensuring smoother warmup behavior and more reliable calculation of the moving average coefficient (ckp1). Recommended to set to True.

True
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
class ScheduleFreeRAdam(BaseOptimizer):
    r"""Schedule-Free RAdam.

    While training, the parameters hold the interpolated iterate; ``state['z']`` holds the base iterate. Call
    ``eval()`` before evaluation and ``train()`` before resuming training so that the parameters are converted
    between the interpolated and the averaged iterate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. beta1 is the interpolation (momentum) coefficient, beta2 is the decay rate of the running
        average of the squared gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param r: float. use polynomial weighting in the average with power r.
    :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
        set to 0 for no weighting.
    :param silent_sgd_phase: bool. the optimizer will not use the first SGD phase of RAdam. This means that the
        optimizer will not update model parameters during the early training steps (e.g., < 5 when β_2 = 0.999), but
        just update the momentum values of the optimizer. This helps stabilize training by ensuring smoother warmup
        behavior and more reliable calculation of the moving average coefficient (`ckp1`). Recommended to set to True.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 2.5e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        silent_sgd_phase: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'silent_sgd_phase': silent_sgd_phase,
            'r': r,
            'weight_lr_power': weight_lr_power,
            'eps': eps,
            'train_mode': True,
            'weight_sum': 0.0,
            'lr_max': -1.0,
            'use_palm': kwargs.get('use_palm', False),
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ScheduleFreeRAdam'

    def eval(self):
        r"""Switch every group that is in training mode to evaluation mode.

        Each parameter with a ``z`` buffer is interpolated in place towards ``z`` with weight ``1 - 1 / beta1``,
        which is the inverse of the interpolation applied by ``train()``.
        """
        for group in self.param_groups:
            beta1, _ = group['betas']
            if group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - 1.0 / beta1)
                group['train_mode'] = False

    def train(self):
        r"""Switch every group that is in evaluation mode back to training mode.

        Each parameter with a ``z`` buffer is interpolated in place towards ``z`` with weight ``1 - beta1``.
        """
        for group in self.param_groups:
            beta1, _ = group['betas']
            if not group['train_mode']:
                for p in group['params']:
                    state = self.state[p]
                    if 'z' in state:
                        p.data.lerp_(end=state['z'], weight=1.0 - beta1)
                group['train_mode'] = True

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state buffers for every parameter of the group that has a gradient and no state yet.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: when a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['z'] = p.clone()
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # `init_group` leaves parameters that already own state untouched, so it is run on every step: a
            # parameter that had no gradient on the first step still gets its buffers once a gradient shows up,
            # instead of failing on the state lookups below.
            self.init_group(group)
            group['step'] = group.get('step', 0) + 1

            beta1, beta2 = group['betas']

            bias_correction2: float = self.debias(beta2, group['step'])

            lr, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=4,
                degenerated_to_sgd=False,
            )
            # a negative step size is replaced by 0.0 when the silent phase is enabled (parameters stay put),
            # and by 1.0 otherwise.
            if lr < 0.0:
                lr = float(not group['silent_sgd_phase'])

            lr_max = group['lr_max'] = max(lr, group['lr_max'])

            # weight of the current iterate in the running average and its share of the accumulated weights.
            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
            weight_sum = group['weight_sum'] = group['weight_sum'] + weight

            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            adaptive_y_lr: float = lr * (beta1 * (1.0 - checkpoint) - 1.0)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z, exp_avg_sq = state['z'], state['exp_avg_sq']

                p, grad, z, exp_avg_sq = self.view_as_real(p, grad, z, exp_avg_sq)

                # exponential moving average of the squared gradient.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if n_sma > 4.0:
                    # NOTE(review): `bias_correction2` divides the square root here rather than `exp_avg_sq`
                    # itself; confirm against `get_rectify_step_size`, which is not visible from this block.
                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2).add_(group['eps'])
                    grad.div_(de_nom)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                # move the parameter towards `z`, then apply the gradient step of the interpolated iterate.
                p.lerp_(z, weight=checkpoint)
                p.add_(grad, alpha=adaptive_y_lr)

                # plain gradient step on the base iterate.
                z.sub_(grad, alpha=lr)

        return loss

ScheduleFreeWrapper

Bases: BaseOptimizer

Wrap any optimizer to make it Schedule-Free.

This version uses a memory-efficient swap operation but may be slower than the reference version. In most cases
the performance difference is negligible. For the best possible performance and memory-usage, Schedule-Free
needs to be directly integrated with the base optimizer.

When using this version, you can disable the base optimizer's momentum, as it's no longer necessary when using
our wrapper's momentum (although you can use both types of momentum if you want).

If you set weight decay on the base optimizer, it computes weight decay at $z$. We offer the option to compute
weight decay at $y$, via the `weight_decay_at_y` parameter, which seems to give better results in our
experiments. This approach to decay only works correctly if the base optimizer uses group['lr'] as the current
learning rate.

Parameters:

Name Type Description Default
optimizer OPTIMIZER_INSTANCE_OR_CLASS

OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.

required
momentum float

float. momentum.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
r float

float. use polynomial weighting in the average with power r.

0.0
weight_lr_power float

float. during warmup, the weights in the average will be equal to lr raised to this power. set to 0 for no weighting.

2.0
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/schedulefree.py
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
class ScheduleFreeWrapper(BaseOptimizer):
    r"""Wrap any optimizer to make it Schedule-Free.

    This version uses a memory-efficient swap operation but may be slower than the reference version. In most cases
    the performance difference is negligible. For the best possible performance and memory-usage, Schedule-Free
    needs to be directly integrated with the base optimizer.

    When using this version, you can disable the base optimizer's momentum, as it's no longer necessary when using
    our wrapper's momentum (although you can use both types of momentum if you want).

    If you set weight decay on the base optimizer, it computes weight decay at $z$. We offer the option to compute
    weight decay at $y$, via the `weight_decay_at_y` parameter, which seems to give better results in our
    experiments. This approach to decay only works correctly if the base optimizer uses group['lr'] as the current
    learning rate.

    The wrapper must be switched to train mode with `train()` before `step()` is called, and back with `eval()`
    before evaluating or checkpointing the averaged weights.

    :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.
    :param momentum: float. momentum. validated to lie in [0, 1); note that `eval()` and `step()` divide by it, so
        a value of exactly 0 raises at that point.
    :param weight_decay: float. weight decay (L2 penalty).
    :param r: float. use polynomial weighting in the average with power r.
    :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
        set to 0 for no weighting.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER_INSTANCE_OR_CLASS,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        r: float = 0.0,
        weight_lr_power: float = 2.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_range(momentum, 'momentum', 0.0, 1.0, '[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.momentum = momentum
        self.weight_decay = weight_decay
        self.r = r
        self.weight_lr_power = weight_lr_power
        # the wrapper starts in eval mode; `step()` refuses to run until `train()` has been called.
        self.train_mode: bool = False
        self.maximize = maximize

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        # per-parameter wrapper state (holds the `z` sequence); the base optimizer keeps its own state.
        self.state: STATE = defaultdict(dict)
        self.defaults: DEFAULTS = self.optimizer.defaults

    def __str__(self) -> str:
        return 'ScheduleFree'

    @property
    def param_groups(self):
        r"""Expose the parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    def __getstate__(self):
        return {'state': self.state, 'optimizer': self.optimizer}

    def add_param_group(self, param_group):
        r"""Forward the new parameter group to the wrapped optimizer."""
        return self.optimizer.add_param_group(param_group)

    def state_dict(self) -> STATE:
        r"""Return the wrapper state together with the wrapped optimizer's state dict."""
        return {'schedulefree_state': self.state, 'base_optimizer': self.optimizer.state_dict()}

    def load_state_dict(self, state: STATE) -> None:
        r"""Load state."""
        self.state = state['schedulefree_state']
        self.optimizer.load_state_dict(state['base_optimizer'])

    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Clear the gradients through the wrapped optimizer."""
        self.optimizer.zero_grad(set_to_none)

    @torch.no_grad()
    def eval(self):
        r"""Switch the parameters from the training point `y` back to the averaged point `x`.

        This is the exact inverse of the interpolation done in `train()`. No-op when already in eval mode.
        """
        if not self.train_mode:
            return

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    # y = momentum * x + (1 - momentum) * z  =>  x = y + (1 - 1 / momentum) * (z - y)
                    p.lerp_(end=state['z'], weight=1.0 - 1.0 / self.momentum)

        self.train_mode = False

    @torch.no_grad()
    def train(self):
        r"""Switch the parameters from the averaged point `x` to the training point `y`.

        No-op when already in train mode.
        """
        if self.train_mode:
            return

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'z' in state:
                    # y = momentum * x + (1 - momentum) * z
                    p.lerp_(end=state['z'], weight=1.0 - self.momentum)

        self.train_mode = True

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the `z` buffer for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if 'z' not in state:
                state['z'] = p.clone()

    @staticmethod
    def swap(x: torch.Tensor, y: torch.Tensor) -> None:
        r"""Exchange the contents of two tensors in place with a byte-wise XOR swap (no temporary buffer).

        `x` and `y` must be distinct storages; XOR-swapping a tensor with itself would zero it.
        """
        x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8))
        y.view(torch.uint8).bitwise_xor_(x.view(torch.uint8))
        x.view(torch.uint8).bitwise_xor_(y.view(torch.uint8))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single Schedule-Free step around the wrapped optimizer.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the loss returned by `closure`, or None when no closure is given.
        :raises ValueError: if the wrapper is not in train mode.
        """
        if not self.train_mode:
            raise ValueError('optimizer was not in train mode when step is called. call .train() before training')

        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # phase 1: apply the wrapper's weight decay, then move `z` into the parameter tensors so that the base
        # optimizer updates `z` (the parameter tensors hold `y` on entry).
        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                z = state['z']

                self.apply_weight_decay(
                    z,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                )

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=self.weight_decay,
                    weight_decouple=True,
                    fixed_decay=False,
                    ratio=1.0 - self.momentum,
                )

                # y -> x (same interpolation as in `eval()`)
                p.lerp_(end=z, weight=1.0 - 1.0 / self.momentum)

                # afterwards `p` holds z and state['z'] holds x
                self.swap(z, p)

        self.optimizer.step()

        # phase 2: swap back, fold the updated `z` into the running average `x`, and rebuild `y`.
        for group in self.param_groups:
            lr: float = group['lr'] * group.get('d', 1.0)
            lr_max = group['lr_max'] = max(lr, group.get('lr_max', 0))

            # polynomial weighting with power `r` (previously the exponent was group['lr'], which left the `r`
            # parameter unused).
            weight: float = (group['step'] ** self.r) * (lr_max ** self.weight_lr_power)  # fmt: skip
            weight_sum = group['weight_sum'] = group.get('weight_sum', 0.0) + weight

            checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                z = state['z']

                # afterwards `p` holds x and state['z'] holds the updated z
                self.swap(z, p)

                # x <- (1 - checkpoint) * x + checkpoint * z
                p.lerp_(end=z, weight=checkpoint)

                # x -> y, ready for the next forward/backward pass
                p.lerp_(end=state['z'], weight=1.0 - self.momentum)

        return loss

load_state_dict(state)

Load state.

Source code in pytorch_optimizer/optimizer/schedulefree.py
568
569
570
571
def load_state_dict(self, state: STATE) -> None:
    r"""Restore the wrapper state and the wrapped optimizer's state from `state`.

    :param state: STATE. dict with the keys 'schedulefree_state' and 'base_optimizer', as built by `state_dict()`.
    """
    self.state = state['schedulefree_state']

    base_optimizer_state = state['base_optimizer']
    self.optimizer.load_state_dict(base_optimizer_state)

SCION

Bases: BaseOptimizer

Training Deep Learning Models with Norm-Constrained LMOs.

Example: >>> radius = 50.0 >>> parameter_groups = [{ ... 'params': model.transformer.h.parameters(), ... 'norm_type': 'spectral', ... 'norm_kwargs': {}, ... 'scale': radius, ... }, { ... 'params': model.lm_head.parameters(), ... 'norm_type': 'sign', ... 'norm_kwargs': {}, ... 'scale': radius * 60.0, ... }] >>> optimizer = SCION(parameter_groups)

For more details, see the examples at https://github.com/LIONS-EPFL/scion/tree/main?tab=readme-ov-file#examples

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor. 1.0 - usual momentum.

0.1
constraint bool

bool. whether to use a constraint SCG or not.

False
norm_type int

int. supported LMO norm types. 0 stands for no normalization and 1 stands for AUTO. 0 to 7. please check LMONorm Enum class for the details.

AUTO
norm_kwargs Optional[Dict]

Optional[Dict]. arguments for the Norm.

None
scale float

float. following the usage in the original implementation, 50.0 is used for Transformer blocks, and 3000.0 is used for the other layers (e.g. Embedding, LM head)

1.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/scion.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
class SCION(BaseOptimizer):
    r"""Training Deep Learning Models with Norm-Constrained LMOs.

    Example:
        >>> radius = 50.0
        >>> parameter_groups = [{
        ...     'params': model.transformer.h.parameters(),
        ...     'norm_type': 'spectral',
        ...     'norm_kwargs': {},
        ...     'scale': radius,
        ... }, {
        ...     'params': model.lm_head.parameters(),
        ...     'norm_type': 'sign',
        ...     'norm_kwargs': {},
        ...     'scale': radius * 60.0,
        ... }]
        >>> optimizer = SCION(parameter_groups)

        For more details, see https://github.com/LIONS-EPFL/scion/tree/main?tab=readme-ov-file#examples

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor. 1.0 - usual momentum.
    :param constraint: bool. whether to use a constraint SCG or not.
    :param norm_type: int. supported LMO norm types. 0 stands for no normalization and 1 stands for AUTO. 0 to 7.
        please check LMONorm Enum class for the details.
    :param norm_kwargs: Optional[Dict]. arguments for the Norm.
    :param scale: float. following the usage in the original implementation, 50.0 is used for Transformer blocks,
        and 3000.0 is used for the other layers (e.g. Embedding, LM head)
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.1,
        constraint: bool = False,
        norm_type: int = LMONorm.AUTO,
        norm_kwargs: Optional[Dict] = None,
        scale: float = 1.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, '(]')
        self.validate_positive(scale, 'scale')

        self.maximize = maximize

        if norm_kwargs is None:
            norm_kwargs = {}

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'constraint': constraint,
            'norm_type': norm_type,
            'norm_kwargs': norm_kwargs,
            'scale': scale,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SCION'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the momentum buffer `d` for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if 'd' not in state:
                state['d'] = torch.zeros_like(grad)

    @torch.no_grad()
    def init(self):
        r"""Initialize every parameter with its group's norm and multiply it by the group's `scale`, in place."""
        for group in self.param_groups:
            norm = build_lmo_norm(group['norm_type'], **group['norm_kwargs'])
            for p in group['params']:
                norm.init(p)
                p.mul_(group['scale'])

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            norm = build_lmo_norm(group['norm_type'], **group['norm_kwargs'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                d = state['d']

                # exponential moving average of the gradient
                d.mul_(1.0 - group['momentum']).add_(grad, alpha=group['momentum'])

                # scale out of place: if the norm's `lmo` hands back its input tensor (presumably the case for the
                # "no normalization" norm -- confirm against `build_lmo_norm`), an in-place multiply would rescale
                # the momentum buffer `d` on every step.
                update = norm.lmo(d) * group['scale']

                if group['constraint']:
                    # constrained variant: shrink the parameter before adding the update
                    p.mul_(1.0 - group['lr'])

                if not group['constraint'] and group['weight_decay'] > 0.0:
                    self.apply_weight_decay(
                        p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=False,
                    )

                p.add_(update, alpha=-group['lr'])

        return loss

StableAdamW

Bases: BaseOptimizer

Stable and low-precision training for large-scale vision-language models.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and the squared gradient.

(0.9, 0.99)
kahan_sum bool

bool. Enables Kahan summation for more accurate parameter updates when training in low precision (float16 or bfloat16).

True
weight_decay float

float. weight decay (L2 penalty).

0.01
weight_decouple bool

bool. decoupled weight decay.

True
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamw.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class StableAdamW(BaseOptimizer):
    r"""Stable and low-precision training for large-scale vision-language models.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and the squared gradient.
    :param kahan_sum: bool. Enables Kahan summation for more accurate parameter updates when training in low precision
        (float16 or bfloat16).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. decoupled weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.99),
        kahan_sum: bool = True,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'kahan_sum': kahan_sum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'StableAdamW'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the moment buffers for every parameter of the group that currently has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

                # the Kahan compensation buffer is only allocated for low-precision parameters
                state['kahan_comp'] = (
                    torch.zeros_like(p)
                    if (group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16})
                    else None
                )

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            beta1_comp: float = 1.0 - self.debias_beta(beta1, group['step'])
            beta2_hat: float = self.debias_beta(beta2, group['step'])

            eps_p2: float = math.pow(group['eps'], 2)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # honour `maximize` the same way the other optimizers do; without this call the flag was
                # accepted and stored but had no effect.
                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                p, grad, exp_avg, exp_avg_sq = self.view_as_real(p, grad, exp_avg, exp_avg_sq)

                exp_avg.lerp_(grad, weight=beta1_comp)
                exp_avg_sq.mul_(beta2_hat).addcmul_(grad, grad, value=1.0 - beta2_hat)

                # per-tensor learning rate: the group learning rate divided by the RMS term
                lr: float = group['lr'] / self.get_stable_adamw_rms(grad, exp_avg_sq, eps=eps_p2)

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                if group['kahan_sum'] and p.dtype in {torch.float16, torch.bfloat16}:
                    kahan_comp = state['kahan_comp']
                    kahan_comp.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

                    # `grad` is no longer needed at this point and is reused as scratch space to remember the
                    # parameter value before the update.
                    grad.copy_(p.detach())
                    p.add_(kahan_comp)

                    # keep the part of the update that was lost to rounding for the next step
                    kahan_comp.add_(grad.sub_(p))
                else:
                    p.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(group['eps']), value=-lr)

        return loss

AccSGD

Bases: BaseOptimizer

Accelerating Stochastic Gradient Descent For Least Squares Regression.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
kappa float

float. ratio of long to short step.

1000.0
xi float

float. statistical advantage parameter.

10.0
constant float

float. any small constant under 1.

0.7
weight_decay float

float. weight decay.

0.0
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
class AccSGD(BaseOptimizer):
    r"""Accelerating Stochastic Gradient Descent For Least Squares Regression.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param kappa: float. ratio of long to short step.
    :param xi: float. statistical advantage parameter.
    :param constant: float. any small constant under 1.
    :param weight_decay: float. weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        kappa: float = 1000.0,
        xi: float = 10.0,
        constant: float = 0.7,
        weight_decay: float = 0.0,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(kappa, 'kappa')
        self.validate_non_negative(xi, 'xi')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_boundary(constant, boundary=1.0, bound_type='upper')

        self.maximize = maximize

        super().__init__(
            params,
            {'lr': lr, 'kappa': kappa, 'xi': xi, 'constant': constant, 'weight_decay': weight_decay},
        )

    def __str__(self) -> str:
        return 'AccSGD'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Seed the momentum buffer of each parameter that has a gradient with a copy of the parameter.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if not param_state:
                param_state['momentum_buffer'] = param.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            small_lr = group['lr']
            kappa = group['kappa']
            constant = group['constant']

            # step sizes and mixing coefficients derived from the group's hyper-parameters
            large_lr: float = small_lr * kappa / constant
            alpha: float = 1.0 - (group['xi'] * (constant ** 2) / kappa)
            beta: float = 1.0 - alpha
            zeta: float = constant / (constant + beta)

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]

                # coupled (L2) weight decay
                self.apply_weight_decay(
                    param,
                    grad=gradient,
                    lr=small_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                # long-step sequence: buf <- beta * ((1 / beta - 1) * buf - large_lr * grad + p)
                momentum_buffer = param_state['momentum_buffer']
                momentum_buffer.mul_((1.0 / beta) - 1.0)
                momentum_buffer.add_(gradient, alpha=-large_lr)
                momentum_buffer.add_(param)
                momentum_buffer.mul_(beta)

                # short step on the parameter, then mix it with the long-step buffer
                param.add_(gradient, alpha=-small_lr)
                param.mul_(zeta)
                param.add_(momentum_buffer, alpha=1.0 - zeta)

        return loss

SGDW

Bases: BaseOptimizer

Decoupled Weight Decay Regularization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0001
momentum float

float. momentum factor.

0.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
dampening float

float. dampening for momentum.

0.0
nesterov bool

bool. enables Nesterov momentum.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
class SGDW(BaseOptimizer):
    r"""Decoupled Weight Decay Regularization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param dampening: float. dampening for momentum.
    :param nesterov: bool. enables Nesterov momentum.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        dampening: float = 0.0,
        nesterov: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'dampening': dampening,
            'nesterov': nesterov,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDW'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create a zero-initialized momentum buffer for every parameter of the group that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                # the buffer accumulates gradients, so it must start from zero. seeding it with a copy of the
                # parameter would turn the first velocity into `momentum * p + grad`. with zeros, the first
                # velocity is `(1 - dampening) * grad`.
                state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: LOSS. the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if momentum > 0.0:
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

                    if group['nesterov']:
                        # look-ahead: step along grad + momentum * buf (modifies `p.grad` in place)
                        grad.add_(buf, alpha=momentum)
                    else:
                        grad = buf

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                p.add_(grad, alpha=-group['lr'])

        return loss

ASGD

Bases: BaseOptimizer

Adaptive SGD with estimation of the local smoothness (curvature).

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
amplifier float

float. amplifier.

0.02
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
theta float

float. theta.

1.0
dampening float

float. dampening for momentum.

1.0
eps float

float. term added to the denominator to improve numerical stability.

1e-05
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
class ASGD(BaseOptimizer):
    r"""Adaptive SGD with estimation of the local smoothness (curvature).

    On every step the learning rate of each parameter group is re-estimated from how much the group-wide parameter
    norm and gradient norm changed since the previous step, and a plain gradient step is then taken with it.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. initial learning rate. the value stored in the parameter group is overwritten on every step.
    :param amplifier: float. controls how fast the learning rate may grow. the candidate learning rate is
        `lr * sqrt(1 + amplifier * theta)`.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param theta: float. initial ratio between two consecutive learning rates. it is updated on every step.
    :param dampening: float. scales the gradient-norm difference in the local smoothness estimate.
    :param eps: float. small constant added to the learning rate when it is bounded by the smoothness estimate.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        amplifier: float = 0.02,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        theta: float = 1.0,
        dampening: float = 1.0,
        eps: float = 1e-5,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(amplifier, 'amplifier')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'amplifier': amplifier,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'theta': theta,
            'dampening': dampening,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ASGD'

    def init_group(self, group: GROUP, **kwargs) -> None:
        # ASGD keeps no per-parameter state; everything it tracks lives in the parameter group itself.
        pass

    @staticmethod
    def get_norms_by_group(group: GROUP, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Get parameter & gradient norm by group.

        Parameters without a gradient are ignored for both norms.

        :param group: GROUP. parameter group whose `params` entry is iterated.
        :param device: torch.device. device on which the two accumulators are allocated.
        :return: one-element float32 tensors holding the L2 norm of the parameters and of the gradients.
        """
        p_norm = torch.zeros(1, dtype=torch.float32, device=device)
        g_norm = torch.zeros(1, dtype=torch.float32, device=device)

        for p in group['params']:
            if p.grad is None:
                continue

            # accumulate squared norms so that the square root below yields the norm over the whole group.
            p_norm.add_(p.norm().pow(2))
            g_norm.add_(p.grad.norm().pow(2))

        p_norm.sqrt_()
        g_norm.sqrt_()

        return p_norm, g_norm

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Re-estimate the learning rate of every group and perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # an empty group has nothing to update and no parameter to take the device from.
            if not group['params']:
                continue

            device = group['params'][0].device

            # (re)create the cached norms when either one is missing, not only when both are.
            if 'prev_param_norm' not in group or 'prev_grad_norm' not in group:
                group['prev_param_norm'], group['prev_grad_norm'] = self.get_norms_by_group(group, device)

            group['curr_param_norm'], group['curr_grad_norm'] = self.get_norms_by_group(group, device)

            param_diff_norm: float = (group['curr_param_norm'] - group['prev_param_norm']).item()
            grad_diff_norm: float = (group['curr_grad_norm'] - group['prev_grad_norm']).item()

            # candidate learning rate grows geometrically; it is capped by the smoothness estimate only when both
            # the parameter norm and the gradient norm increased since the previous step.
            new_lr: float = group['lr'] * math.sqrt(1 + group['amplifier'] * group['theta'])
            if param_diff_norm > 0 and grad_diff_norm > 0:
                new_lr = min(new_lr, param_diff_norm / (group['dampening'] * grad_diff_norm)) + group['eps']

            group['theta'] = new_lr / group['lr']
            group['lr'] = new_lr

            group['prev_param_norm'].copy_(group['curr_param_norm'])
            group['prev_grad_norm'].copy_(group['curr_grad_norm'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                # honour the `maximize` flag. the norms above are sign-invariant, so flipping here is enough.
                self.maximize_gradient(grad, maximize=self.maximize)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.add_(grad, alpha=-new_lr)

        return loss

get_norms_by_group(group, device) staticmethod

Get parameter & gradient norm by group.

Source code in pytorch_optimizer/optimizer/sgd.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
@staticmethod
def get_norms_by_group(group: GROUP, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Get parameter & gradient norm by group."""
    p_norm = torch.zeros(1, dtype=torch.float32, device=device)
    g_norm = torch.zeros(1, dtype=torch.float32, device=device)

    for p in group['params']:
        if p.grad is None:
            continue

        p_norm.add_(p.norm().pow(2))
        g_norm.add_(p.grad.norm().pow(2))

    p_norm.sqrt_()
    g_norm.sqrt_()

    return p_norm, g_norm

SignSGD

Bases: BaseOptimizer

Compressed Optimisation for Non-Convex Problems.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor (0.0 = SignSGD, >0 = Signum).

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
class SignSGD(BaseOptimizer):
    r"""Compressed Optimisation for Non-Convex Problems.

    Every parameter is moved by `lr` times the sign of its gradient (or of the momentum buffer when
    `momentum > 0`).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor (0.0 = SignSGD, >0 = Signum).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        # report the offending argument under its real name (`momentum`) in the validation error.
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SignSGD'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the missing momentum buffers of a group.

        The method is idempotent: existing buffers are left untouched, so it can be called on every step to pick up
        parameters that receive their first gradient later in training.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if group['momentum'] > 0.0 and 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # called on every step (not only the first) so that parameters whose gradient shows up later still
            # get a momentum buffer instead of failing with a KeyError below.
            self.init_group(group)

            if 'step' not in group:
                group['step'] = 1
            else:
                group['step'] += 1

            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if momentum > 0.0:
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
                else:
                    buf = grad

                # `torch.sign` is not defined for complex tensors; `torch.sgn` is its complex counterpart.
                p.add_(torch.sign(buf) if not torch.is_complex(buf) else torch.sgn(buf), alpha=-group['lr'])

        return loss

SGDSaI

Bases: BaseOptimizer

No More Adam: Learning Rate Scaling at Initialization is All You Need.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
momentum float

float. coefficients used for computing running averages of gradient.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.01
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
class SGDSaI(BaseOptimizer):
    r"""No More Adam: Learning Rate Scaling at Initialization is All You Need.

    A per-parameter scale (stored under the `gsnr` state key) is computed once from the gradient during a warm-up
    pass and afterwards multiplies the learning rate of a momentum SGD update.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. coefficients used for computing running averages of gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 1e-2,
        weight_decouple: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.has_warmup: bool = False
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDSaI'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the missing momentum buffers of a group.

        The method is idempotent: existing buffers are left untouched, so it can be called on every step.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if group['momentum'] > 0.0 and 'momentum_buffer' not in state:
                state['momentum_buffer'] = torch.zeros_like(p)

    @staticmethod
    def _compute_gsnr(grad: torch.Tensor, eps: float) -> torch.Tensor:
        r"""Return the scale stored under `gsnr`: gradient norm divided by (gradient std + eps).

        The plain gradient norm is returned for gradients with at most one dimension, with a leading dimension of
        size one, or with a standard deviation of exactly zero. The result does not depend on the sign of `grad`.
        """
        grad_norm = grad.norm()
        if grad.ndim <= 1 or grad.size(0) == 1:
            return grad_norm

        sigma = grad.std().nan_to_num_()
        if sigma == 0.0:
            return grad_norm

        return grad_norm.div_(sigma.add_(eps))

    @torch.no_grad()
    def warmup_step(self, closure: CLOSURE = None) -> LOSS:
        r"""Compute and store the `gsnr` scale of every parameter that currently has a gradient.

        The gradients are only read, never modified. In particular they are not negated for `maximize=True`: the
        scale is sign-invariant, and an in-place flip here would be undone by the flip in `step` when both run on
        the same gradients.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self.init_group(group)
            if 'step' not in group:
                group['step'] = 1

            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['gsnr'] = self._compute_gsnr(p.grad, group['eps'])

        self.has_warmup = True

        return loss

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step, running the warm-up pass first if it has not happened yet.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        loss: LOSS = None
        if not self.has_warmup:
            # the warm-up pass already evaluates the closure; reuse its loss and gradients instead of running a
            # second forward/backward pass on the very first step.
            loss = self.warmup_step(closure)
        elif closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # idempotent; picks up parameters (or whole groups) that had no gradient during the warm-up pass.
            self.init_group(group)

            if 'step' not in group:
                group['step'] = 1
            else:
                group['step'] += 1

            momentum: float = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if 'gsnr' not in state:
                    # parameter had no gradient during warm-up: derive its scale from the first gradient seen.
                    state['gsnr'] = self._compute_gsnr(grad, group['eps'])

                if momentum > 0.0:
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
                else:
                    buf = grad

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                p.add_(buf, alpha=-group['lr'] * state['gsnr'])

        return loss

SGDP

Bases: BaseOptimizer

SGD + Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor.

0.0
dampening float

float. dampening for momentum.

0.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
delta float

float. threshold that determines whether a set of parameters is scale invariant or not.

0.1
wd_ratio float

float. relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters.

0.1
nesterov bool

bool. enables nesterov momentum.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/adamp.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class SGDP(BaseOptimizer):
    r"""SGD + Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor. it is validated because `1 - momentum` is used as a divisor when the
        weight decay ratio is computed.
    :param dampening: float. dampening for momentum.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param delta: float. threshold that determines whether a set of parameters is scale invariant or not.
    :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied on
        scale-variant parameters.
    :param nesterov: bool. enables nesterov momentum.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        delta: float = 0.1,
        wd_ratio: float = 0.1,
        nesterov: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        # `step` divides by `1.0 - momentum`; reject out-of-range values up front instead of failing mid-training.
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(wd_ratio, 'wd_ratio', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'momentum': momentum,
            'dampening': dampening,
            'delta': delta,
            'wd_ratio': wd_ratio,
            'nesterov': nesterov,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDP'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the missing momentum buffers of a group.

        The method is idempotent: existing buffers are left untouched, so it can be called on every step to pick up
        parameters that receive their first gradient later in training.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]
            if 'momentum' not in state:
                state['momentum'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # called on every step (not only the first) so that parameters whose gradient shows up later still
            # get a momentum buffer instead of failing with a KeyError below.
            self.init_group(group)

            if 'step' not in group:
                group['step'] = 1
            else:
                group['step'] += 1

            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                buf = state['momentum']

                p, grad, buf = self.view_as_real(p, grad, buf)

                buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

                # work on a copy so that the projection / nesterov look-ahead never alters the stored buffer.
                d_p = buf.clone()
                if group['nesterov']:
                    d_p = d_p.mul_(momentum).add_(grad)

                # the projection is only attempted for tensors with more than one dimension.
                wd_ratio: float = 1.0
                if len(p.shape) > 1:
                    d_p, wd_ratio = projection(
                        p,
                        grad,
                        d_p,
                        group['delta'],
                        group['wd_ratio'],
                        group['eps'],
                    )

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=wd_ratio / (1.0 - momentum),
                )

                p.add_(d_p, alpha=-group['lr'])

        return loss

Shampoo

Bases: BaseOptimizer

Preconditioned Stochastic Tensor Optimization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum.

0.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
preconditioning_compute_steps int

int. performance tuning params for controlling memory and compute requirements. How often to compute pre-conditioner.

1
matrix_eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/shampoo.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class Shampoo(BaseOptimizer):
    r"""Preconditioned Stochastic Tensor Optimization.

    For every dimension of a parameter a square statistics matrix (`pre_cond_{i}`) and a derived matrix
    (`inv_pre_cond_{i}`) are kept in the optimizer state. The gradient is multiplied by the derived matrix along each
    dimension in turn before the parameter update.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
        requirements. How often to compute pre-conditioner.
    :param matrix_eps: float. scale of the identity matrix each `pre_cond_{i}` statistics matrix is initialized with.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        preconditioning_compute_steps: int = 1,
        matrix_eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_step(preconditioning_compute_steps, 'preconditioning_compute_steps')
        self.validate_non_negative(matrix_eps, 'matrix_eps')

        # kept on the instance (not in the group defaults), so it is shared by all parameter groups.
        self.preconditioning_compute_steps = preconditioning_compute_steps
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'matrix_eps': matrix_eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Shampoo'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state of every parameter in the group that has a gradient and no state yet.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a parameter has a sparse gradient.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                if group['momentum'] > 0.0:
                    # the buffer starts as a copy of the current gradient, not as zeros.
                    state['momentum_buffer'] = grad.clone()

                # one (dim x dim) pair per tensor dimension: statistics start at matrix_eps * I, the derived
                # matrix starts at zero and is filled in by `step`.
                for dim_id, dim in enumerate(grad.size()):
                    state[f'pre_cond_{dim_id}'] = group['matrix_eps'] * torch.eye(dim, out=grad.new(dim, dim))
                    state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # NOTE(review): state is created only on the first step of a group; a parameter that receives its
            # first gradient on a later step has no `pre_cond_*` entries when they are looked up below.
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if momentum > 0.0:
                    # blends the stored buffer into the gradient in place, i.e. `p.grad` itself is modified.
                    grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                order: int = grad.ndimension()
                original_size: torch.Size = grad.size()
                for dim_id, dim in enumerate(grad.size()):
                    pre_cond, inv_pre_cond = state[f'pre_cond_{dim_id}'], state[f'inv_pre_cond_{dim_id}']

                    # bring the current dimension to the front. for dim_id == 0 the in-place transpose is a no-op;
                    # on later iterations `grad` is the tensor produced by the previous iteration.
                    grad = grad.transpose_(0, dim_id).contiguous()
                    transposed_size = grad.size()

                    # flatten to a (dim, rest) matrix and accumulate its Gram matrix into the statistics.
                    grad = grad.view(dim, -1)
                    grad_t = grad.t()

                    pre_cond.add_(grad @ grad_t)
                    if group['step'] % self.preconditioning_compute_steps == 0:
                        # presumably a matrix power of `pre_cond` depending on `order` -- helper is defined
                        # elsewhere, TODO confirm.
                        inv_pre_cond.copy_(compute_power_svd(pre_cond, order))

                    if dim_id == order - 1:
                        # last dimension: multiply from the right and restore the shape the gradient started with.
                        grad = grad_t @ inv_pre_cond
                        grad = grad.view(original_size)
                    else:
                        grad = inv_pre_cond @ grad
                        grad = grad.view(transposed_size)

                # stored unconditionally, also when momentum == 0.0 and the buffer is never read.
                state['momentum_buffer'] = grad

                p.add_(grad, alpha=-group['lr'])

        return loss

ScalableShampoo

Bases: BaseOptimizer

Scalable Preconditioned Stochastic Tensor Optimization.

This version of the Scalable Shampoo optimizer targets a single-GPU environment, not a distributed environment
or XLA devices. The original design computes the pre-conditioners asynchronously on distributed CPUs, whereas
this implementation computes them synchronously on the GPU, which takes about 99% of the optimization time.

Still, it is much faster than the previous Shampoo optimizer because it uses a coupled Newton iteration to
compute the G^{-1/p} matrices, whereas the previous one uses SVD, which is really slow.

Also, this implementation offers
    1. lots of plug-ins (e.g. gradient grafting, type of pre-conditioning, etc)
    2. not-yet implemented features in the official Pytorch code.
    3. readable, organized, clean code.

Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. beta1, beta2.

(0.9, 0.999)
moving_average_for_momentum bool

bool. perform moving_average for momentum (beta1).

False
weight_decay float

float. weight decay (L2 penalty).

0.0
decoupled_weight_decay bool

bool. use decoupled weight_decay.

False
decoupled_learning_rate bool

bool. use decoupled lr, otherwise couple it w/ preconditioned gradient.

True
inverse_exponent_override int

int. fixed exponent for pre-conditioner, if > 0.

0
start_preconditioning_step int

int.

25
preconditioning_compute_steps int

int. performance tuning parameter for controlling memory and compute requirements: how often to compute the pre-conditioner. Ideally, 1 is the best. However, the current implementation does not work in a distributed environment (statistics and pre-conditioners are not synchronized among replicas), computes on the GPU (not the CPU), and uses fp32 (not fp64) precision. Also, according to the paper, preconditioning_compute_steps does not have a significant effect on the performance. So, if speed is a problem, try setting this to a larger value (e.g. 1000).

1000
statistics_compute_steps int

int. How often to compute statistics. usually set to 1 (or 10).

1
block_size int

int. Block size for large layers (if > 0). Block size = 1 ==> AdaGrad (Don't do this, extremely inefficient!) Block size should be as large as feasible under memory/time constraints.

512
skip_preconditioning_rank_lt int

int. Skips preconditioning for parameters with rank less than this value.

1
no_preconditioning_for_layers_with_dim_gt int

int. avoid preconditioning large layers to reduce overall memory.

8192
shape_interpretation bool

bool. Automatic shape interpretation (e.g. [4, 3, 1024, 512] would result in 12 x [1024, 512] L and R statistics). When disabled, Shampoo constructs statistics [4, 4], [3, 3], [1024, 1024], [512, 512].

True
graft_type int

int. type of grafting (SGD or AdaGrad or RMSProp or SQRT_N or None).

SGD
pre_conditioner_type int

int. type of pre-conditioner.

ALL
nesterov bool

bool. Nesterov momentum.

True
diagonal_eps float

float. term added to the denominator to improve numerical stability.

1e-10
matrix_eps float

float. term added to the denominator to improve numerical stability.

1e-06
use_svd bool

bool. use SVD instead of the Schur-Newton method to calculate M^{-1/p}. Theoretically, the Schur-Newton method is faster than the SVD method. However, because of the inefficiency of the loop code and the availability of a proper SVD kernel, SVD is much faster in some cases (usually for small models). See https://github.com/kozistr/pytorch_optimizer/pull/103

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/shampoo.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
class ScalableShampoo(BaseOptimizer):
    r"""Scalable Preconditioned Stochastic Tensor Optimization.

        This version of Scalable Shampoo Optimizer aims for a single GPU environment, not for a distributed environment
        or XLA devices. So, the original intention is to compute pre-conditioners asynchronously on the distributed
        CPUs, but this implementation calculates them which takes 99% of the optimization time on a GPU synchronously.

        Still, it is much faster than the previous Shampoo Optimizer because using coupled Newton iteration when
        computing G^{-1/p} matrices while the previous one uses SVD which is really slow.

        Also, this implementation offers
            1. lots of plug-ins (e.g. gradient grafting, type of pre-conditioning, etc)
            2. not-yet implemented features in the official Pytorch code.
            3. readable, organized, clean code.

        Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. beta1, beta2.
    :param moving_average_for_momentum: bool. perform moving_average for momentum (beta1).
    :param weight_decay: float. weight decay (L2 penalty).
    :param decoupled_weight_decay: bool. use decoupled weight_decay.
    :param decoupled_learning_rate: bool. use decoupled lr, otherwise couple it w/ preconditioned gradient.
    :param inverse_exponent_override: int. fixed exponent for pre-conditioner, if > 0.
    :param start_preconditioning_step: int. first step (1-based, per parameter group) at which the preconditioned
        gradient and the Shampoo momentum buffer are used; before it the raw gradient and the graft momentum are used.
    :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
        requirements. How often to compute pre-conditioner. Ideally, 1 is the best. However, the current implementation
        doesn't work on the distributed environment (there are no statistics & pre-conditioners sync among replicas),
        compute on the GPU (not CPU) and the precision is fp32 (not fp64).
        Also, followed by the paper, `preconditioning_compute_steps` does not have a significant effect on the
        performance. So, If you have a problem with the speed, try to set this step bigger (e.g. 1000).
    :param statistics_compute_steps: int. How often to compute statistics. usually set to 1 (or 10).
    :param block_size: int. Block size for large layers (if > 0).
        Block size = 1 ==> AdaGrad (Don't do this, extremely inefficient!)
        Block size should be as large as feasible under memory/time constraints.
    :param skip_preconditioning_rank_lt: int. Skips preconditioning for parameters with rank less than this value.
    :param no_preconditioning_for_layers_with_dim_gt: int. avoid preconditioning large layers to reduce overall memory.
    :param shape_interpretation: bool. Automatic shape interpretation (e.g. [4, 3, 1024, 512] would
        result in 12 x [1024, 512] L and R statistics). When disabled, Shampoo constructs
        statistics [4, 4], [3, 3], [1024, 1024], [512, 512]. Enabled by default in this implementation.
    :param graft_type: int. type of grafting (SGD or AdaGrad or RMSProp or SQRT_N or None).
    :param pre_conditioner_type: int. type of pre-conditioner.
    :param nesterov: bool. Nesterov momentum.
    :param diagonal_eps: float. term added to the denominator to improve numerical stability.
    :param matrix_eps: float. term added to the denominator to improve numerical stability.
    :param use_svd: bool. use SVD instead of Schur-Newton method to calculate M^{-1/p}.
        Theoretically, the Schur-Newton method is faster than the SVD method. However, because of the inefficiency of
        the loop code and the availability of a proper SVD kernel, SVD is much faster in some cases (usually in the
        case of small models). see https://github.com/kozistr/pytorch_optimizer/pull/103
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        moving_average_for_momentum: bool = False,
        weight_decay: float = 0.0,
        decoupled_weight_decay: bool = False,
        decoupled_learning_rate: bool = True,
        inverse_exponent_override: int = 0,
        start_preconditioning_step: int = 25,
        preconditioning_compute_steps: int = 1000,
        statistics_compute_steps: int = 1,
        block_size: int = 512,
        skip_preconditioning_rank_lt: int = 1,
        no_preconditioning_for_layers_with_dim_gt: int = 8192,
        shape_interpretation: bool = True,
        graft_type: int = LayerWiseGrafting.SGD,
        pre_conditioner_type: int = PreConditionerType.ALL,
        nesterov: bool = True,
        diagonal_eps: float = 1e-10,
        matrix_eps: float = 1e-6,
        use_svd: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        # Validate hyper-parameters up front; the validators are inherited from BaseOptimizer.
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_step(start_preconditioning_step, 'start_preconditioning_step')
        self.validate_step(preconditioning_compute_steps, 'preconditioning_compute_steps')
        self.validate_step(statistics_compute_steps, 'statistics_compute_steps')
        self.validate_non_negative(diagonal_eps, 'diagonal_eps')
        self.validate_non_negative(matrix_eps, 'matrix_eps')

        # These settings are stored on the optimizer itself (not in `defaults`), so they are shared by every
        # parameter group and cannot be overridden per group.
        self.inverse_exponent_override = inverse_exponent_override
        self.start_preconditioning_step = start_preconditioning_step
        self.preconditioning_compute_steps = preconditioning_compute_steps
        self.statistics_compute_steps = statistics_compute_steps
        self.block_size = block_size
        self.skip_preconditioning_rank_lt = skip_preconditioning_rank_lt
        self.no_preconditioning_for_layers_with_dim_gt = no_preconditioning_for_layers_with_dim_gt
        self.shape_interpretation = shape_interpretation
        self.graft_type = graft_type
        self.pre_conditioner_type = pre_conditioner_type
        self.diagonal_eps = diagonal_eps
        self.matrix_eps = matrix_eps
        self.use_svd = use_svd
        self.maximize = maximize

        # Per-group hyper-parameters; these are the values read from `group[...]` inside `step`.
        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'decoupled_weight_decay': decoupled_weight_decay,
            'decoupled_learning_rate': decoupled_learning_rate,
            'moving_average_for_momentum': moving_average_for_momentum,
            'nesterov': nesterov,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ScalableShampoo'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the per-parameter state for every parameter of `group` that currently has a gradient.

        Each such parameter gets a zero `momentum` buffer, a `PreConditioner` and a graft object. Parameters whose
        gradient is `None` at this point are skipped and receive no state here.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        _, beta2 = group['betas']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            # Only initialize parameters that do not have any state yet.
            if len(state) == 0:
                state['momentum'] = torch.zeros_like(grad)
                state['pre_conditioner'] = PreConditioner(
                    p,
                    beta2,
                    self.inverse_exponent_override,
                    self.block_size,
                    self.skip_preconditioning_rank_lt,
                    self.no_preconditioning_for_layers_with_dim_gt,
                    self.shape_interpretation,
                    self.pre_conditioner_type,
                    self.matrix_eps,
                    self.use_svd,
                )
                state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)

    def is_precondition_step(self, step: int) -> bool:
        r"""Return True once `step` has reached `start_preconditioning_step`."""
        return step >= self.start_preconditioning_step

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model (with gradients enabled) and returns
            the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # State is created only on the first call for a group; the step counter is 1-based.
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            is_precondition_step: bool = self.is_precondition_step(group['step'])
            # With a coupled learning rate, the gradient handed to the graft is pre-scaled by `lr`.
            pre_conditioner_multiplier: float = 1.0 if group['decoupled_learning_rate'] else group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                pre_conditioner, graft = state['pre_conditioner'], state['graft']

                # Graft statistics are refreshed every step; Shampoo statistics and the pre-conditioners themselves
                # are refreshed only every `statistics_compute_steps` / `preconditioning_compute_steps` steps.
                graft.add_statistics(grad, beta2)
                if group['step'] % self.statistics_compute_steps == 0:
                    pre_conditioner.add_statistics(grad)
                if group['step'] % self.preconditioning_compute_steps == 0:
                    pre_conditioner.compute_pre_conditioners()

                graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                # NOTE(review): before `start_preconditioning_step`, `shampoo_grad` is `grad` (i.e. `p.grad`) itself,
                # not a copy, so the in-place operations on `shampoo_grad` below also modify `p.grad`.
                shampoo_grad: torch.Tensor = (
                    pre_conditioner.preconditioned_grad(grad) if is_precondition_step else grad
                )

                # Grafting: rescale the Shampoo direction so that its norm equals the norm of the graft direction.
                # The 1e-16 term guards against division by a zero norm.
                if self.graft_type != LayerWiseGrafting.NONE:
                    graft_norm = torch.linalg.norm(graft_grad)
                    shampoo_norm = torch.linalg.norm(shampoo_grad)

                    shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

                # NOTE(review): `apply_weight_decay` is called once per direction, i.e. twice per parameter. Its
                # implementation is not visible here; with `decoupled_weight_decay=True` it presumably decays `p`
                # on each call - confirm that applying it twice is intended.
                for g in (graft_grad, shampoo_grad):
                    self.apply_weight_decay(
                        p,
                        grad=g,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['decoupled_weight_decay'],
                        fixed_decay=False,
                    )

                # Both momentum buffers are advanced on every step, regardless of which one is used for the update.
                state['momentum'].mul_(beta1).add_(shampoo_grad)
                graft_momentum = graft.update_momentum(grad, beta1)

                momentum_update = state['momentum'] if is_precondition_step else graft_momentum

                if group['nesterov']:
                    w: float = (1.0 - beta1) if group['moving_average_for_momentum'] else 1.0

                    wd_update = shampoo_grad if is_precondition_step else graft_grad
                    wd_update.mul_(w)

                    # In-place: `momentum_update` is the stored buffer selected above (`state['momentum']` on
                    # preconditioned steps), so this look-ahead also changes the persisted momentum.
                    momentum_update.mul_(beta1).add_(wd_update)

                p.add_(momentum_update, alpha=-group['lr'])

        return loss

SM3

Bases: BaseOptimizer

Memory-Efficient Adaptive Optimization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.1
momentum float

float. coefficient used to scale prior updates before adding. This drastically increases memory usage if momentum > 0.0. This is ignored if the parameter's gradient is sparse.

0.0
beta float

float. coefficient used for exponential moving averages.

0.0
eps float

float. term added to the denominator to improve numerical stability.

1e-30
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sm3.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
class SM3(BaseOptimizer):
    r"""Memory-Efficient Adaptive Optimization (SM3).

    Dense gradients keep one second-moment accumulator per tensor axis; sparse gradients keep a single accumulator
    indexed by the first axis.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. coefficient used to scale prior updates before adding. This is ignored if the
        parameter's gradient is sparse.
    :param beta: float. coefficient used for exponential moving averages.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        momentum: float = 0.0,
        beta: float = 0.0,
        eps: float = 1e-30,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        hyper_parameters: DEFAULTS = dict(lr=lr, momentum=momentum, beta=beta, eps=eps)

        super().__init__(params, hyper_parameters)

    def __str__(self) -> str:
        return 'SM3'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Allocate the momentum buffer and the accumulators for parameters of `group` that have a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        for param in group['params']:
            grad = param.grad
            if grad is None:
                continue

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            state = self.state[param]
            if len(state) > 0:
                # already initialized
                continue

            dims = grad.shape
            n_axes: int = len(dims)

            state['momentum_buffer'] = torch.zeros_like(grad)

            if grad.is_sparse:
                # sparse gradients use a single accumulator along the first axis
                state['accumulator_0'] = torch.zeros(dims[0], dtype=grad.dtype, device=grad.device)
                continue

            if n_axes == 0:
                state['accumulator_0'] = torch.zeros_like(grad)
                continue

            # one broadcastable accumulator per axis: size `dims[axis]` on that axis and 1 everywhere else
            for axis, extent in enumerate(dims):
                acc_shape = [1] * n_axes
                acc_shape[axis] = extent
                state[f'accumulator_{axis}'] = torch.zeros(acc_shape, dtype=grad.dtype, device=grad.device)

    @staticmethod
    def make_sparse(grad: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        r"""Build a sparse tensor with the indices and size of `grad` and the given `values`.

        When either the indices or the values are 0-dimensional, a new tensor resized like `grad` is returned instead.
        """
        indices = grad._indices()
        if indices.dim() != 0 and values.dim() != 0:
            return grad.new(indices, values, grad.size())
        return grad.new().resize_as_(grad)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            momentum = group['momentum']
            beta = group['beta']
            use_moving_average: bool = beta > 0.0

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]

                if grad.is_sparse:
                    grad = grad.coalesce()
                    indices, values = grad._indices(), grad._values()

                    accumulator = state['accumulator_0']

                    # second-moment estimate for the touched entries only
                    nu = torch.gather(accumulator, 0, indices[0])
                    if use_moving_average:
                        nu.mul_(beta)
                    nu.addcmul_(values, values, value=1.0 - beta)

                    reduced = reduce_max_except_dim(self.make_sparse(grad, nu).to_dense(), 0).squeeze_()

                    if use_moving_average:
                        torch.max(accumulator, reduced, out=accumulator)
                    else:
                        accumulator.copy_(reduced)

                    nu.add_(group['eps']).rsqrt_().mul_(values)

                    # momentum is not applied on the sparse path
                    update = self.make_sparse(grad, nu)
                else:
                    n_axes: int = len(grad.shape)

                    # element-wise minimum over the per-axis accumulators
                    update = state['accumulator_0'].clone()
                    for axis in range(1, n_axes):
                        update = torch.min(update, state[f'accumulator_{axis}'])

                    if use_moving_average:
                        update.mul_(beta)
                    update.addcmul_(grad, grad, value=1.0 - beta)

                    # refresh every per-axis accumulator from the new estimate
                    for axis in range(n_axes):
                        accumulator = state[f'accumulator_{axis}']
                        reduced = reduce_max_except_dim(update, axis)
                        if use_moving_average:
                            torch.max(accumulator, reduced, out=accumulator)
                        else:
                            accumulator.copy_(reduced)

                    update.add_(group['eps']).rsqrt_().mul_(grad)

                    if momentum > 0.0:
                        buffer = state['momentum_buffer']
                        buffer.mul_(momentum).add_(update, alpha=1.0 - momentum)
                        update = buffer

                param.add_(update, alpha=-group['lr'])

        return loss

SOAP

Bases: BaseOptimizer

Improving and Stabilizing Shampoo using Adam.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.003
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and of the squared projected gradient.

(0.95, 0.95)
shampoo_beta Optional[float]

Optional[float]. if not None, use this beta for the pre-conditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].

None
weight_decay float

float. weight decay (L2 penalty).

0.01
precondition_frequency int

int. how often to update the pre-conditioner.

10
max_precondition_dim int

int. maximum dimension of the pre-conditioner. Set to 10000, so that we exclude most common vocab sizes while including layers.

10000
merge_dims bool

bool. whether to merge dimensions of the pre-conditioner

False
precondition_1d bool

bool. whether to precondition 1D gradients.

False
correct_bias bool

bool. whether to correct bias in Adam.

True
normalize_gradient bool

bool. whether to normalize the gradients.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/soap.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
class SOAP(BaseOptimizer):
    r"""Improving and Stabilizing Shampoo using Adam.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the squared
        projected gradient.
    :param shampoo_beta: Optional[float]. if not None, use this beta for the pre-conditioner (L and R in paper,
        state['GG'] below) moving average instead of betas[1].
    :param weight_decay: float. weight decay (L2 penalty).
    :param precondition_frequency: int. how often to update the pre-conditioner.
    :param max_precondition_dim: int. maximum dimension of the pre-conditioner. Set to 10000, so that we exclude most
        common vocab sizes while including layers.
    :param merge_dims: bool. whether to merge dimensions of the pre-conditioner
    :param precondition_1d: bool. whether to precondition 1D gradients.
    :param correct_bias: bool. whether to correct bias in Adam.
    :param normalize_gradient: bool. whether to normalize the gradients.
    :param data_format: DATA_FORMAT. 'channels_first' or 'channels_last'. Only consulted for 4D tensors when
        dimensions are merged.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 3e-3,
        betas: BETAS = (0.95, 0.95),
        shampoo_beta: Optional[float] = None,
        weight_decay: float = 1e-2,
        precondition_frequency: int = 10,
        max_precondition_dim: int = 10000,
        merge_dims: bool = False,
        precondition_1d: bool = False,
        correct_bias: bool = True,
        normalize_gradient: bool = False,
        data_format: DATA_FORMAT = 'channels_first',
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        # NOTE(review): `shampoo_beta` defaults to None; this presumes the validator accepts None - confirm.
        self.validate_non_negative(shampoo_beta, 'shampoo_beta')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(precondition_frequency, 'precondition_frequency')
        self.validate_positive(max_precondition_dim, 'max_precondition_dim')
        self.validate_non_negative(eps, 'eps')

        # Optimizer-wide settings (not stored per parameter group).
        self.data_format = data_format
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'shampoo_beta': shampoo_beta,
            'weight_decay': weight_decay,
            'precondition_frequency': precondition_frequency,
            'max_precondition_dim': max_precondition_dim,
            'merge_dims': merge_dims,
            'precondition_1d': precondition_1d,
            'correct_bias': correct_bias,
            'normalize_gradient': normalize_gradient,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SOAP'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the Adam moments and the pre-conditioner state for parameters of `group` that have a gradient.

        Reads `group['step']`, so the step counter must be set on the group before this is called.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex-valued.
        """
        _, beta2 = group['betas']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.zeros_like(grad)
                state['exp_avg_sq'] = torch.zeros_like(grad)

                # Falls back to beta2 when no dedicated `shampoo_beta` was configured.
                self.init_pre_conditioner(
                    grad,
                    state,
                    precondition_frequency=group['precondition_frequency'],
                    shampoo_beta=group['shampoo_beta'] if group['shampoo_beta'] is not None else beta2,
                    max_precondition_dim=group['max_precondition_dim'],
                    precondition_1d=group['precondition_1d'],
                    merge_dims=group['merge_dims'],
                )

                # Seeds `GG` with the current gradient and, because `Q` is still None, computes the initial
                # eigen-bases.
                self.update_pre_conditioner(
                    grad,
                    state,
                    step=group['step'],
                    max_precondition_dim=group['max_precondition_dim'],
                    precondition_1d=group['precondition_1d'],
                    merge_dims=group['merge_dims'],
                )

    def project(
        self,
        grad: torch.Tensor,
        state,
        merge_dims: bool = False,
        max_precondition_dim: int = 10000,
        project_type: str = 'forward',
    ) -> torch.Tensor:
        r"""Rotate `grad` into (`forward`) or out of (any other value, e.g. `backward`) the basis in `state['Q']`.

        :param grad: torch.Tensor. tensor to project.
        :param state: per-parameter state; `state['Q']` holds one matrix (or an empty list) per axis.
        :param merge_dims: bool. whether to merge small dimensions before projecting.
        :param max_precondition_dim: int. passed to `merge_small_dims` when merging.
        :param project_type: str. 'forward' contracts with axis 0 of each matrix, otherwise with axis 1.
        :return: the projected tensor, reshaped back to the input shape when dimensions were merged.
        """
        original_shape = grad.shape

        if merge_dims:
            # NOTE(review): only the *shape* of the permuted view is kept here; `grad` itself is reshaped below
            # without being permuted first - confirm this is intended for 'channels_last'.
            if self.data_format == 'channels_last' and grad.dim() == 4:
                permuted_shape = grad.permute(0, 3, 1, 2).shape

            grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))

        # Each iteration consumes the leading axis: tensordot contracts axis 0 and appends the result as the last
        # axis, while the permute for un-preconditioned axes just rotates axis 0 to the end. After one pass over
        # every entry of `Q`, the axes are back in their original order.
        for mat in state['Q']:
            if len(mat) > 0:
                grad = torch.tensordot(grad, mat, dims=[[0], [0 if project_type == 'forward' else 1]])
            else:
                grad = grad.permute([*list(range(1, len(grad.shape))), 0])

        if merge_dims:
            if self.data_format == 'channels_last' and len(original_shape) == 4:
                grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                grad = grad.reshape(original_shape)

        return grad

    @staticmethod
    def get_orthogonal_matrix(mat: torch.Tensor) -> List[torch.Tensor]:
        r"""Compute the eigen-bases of the pre-conditioner matrices with `torch.linalg.eigh`.

        :param mat: sequence of matrices; entries with length 0 are passed through as empty lists.
        :return: one eigenvector matrix (or an empty list) per entry of `mat`.
        """
        matrices: List = []
        for m in mat:
            if len(m) == 0:
                matrices.append([])
                continue

            # A tiny multiple of the identity is added before the decomposition; if it fails in the native dtype,
            # retry in float64 and cast the eigenvectors back.
            try:
                _, q = torch.linalg.eigh(m + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=m.dtype))
            except Exception:  # pragma: no cover
                _, q = torch.linalg.eigh(
                    m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device, dtype=torch.float64)
                )
                q = q.to(m.dtype)

            # eigh returns eigenvalues in ascending order; reverse the columns to get descending order.
            q = torch.flip(q, dims=[1])

            matrices.append(q)

        return matrices

    def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, merge_dims: bool = False):
        r"""Compute the eigen-bases of the pre-conditioner using one round of power iteration."""
        orig_shape = state['exp_avg_sq'].shape
        if self.data_format == 'channels_last' and len(orig_shape) == 4:
            permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape

        exp_avg_sq = state['exp_avg_sq']
        if merge_dims:
            exp_avg_sq = exp_avg_sq.reshape(merge_small_dims(exp_avg_sq.size(), max_precondition_dim))

        matrices = []
        for ind, (m, o) in enumerate(zip(state['GG'], state['Q'])):
            if len(m) == 0:
                matrices.append([])
                continue

            # Estimate the eigenvalues in the current basis and sort them in descending order; the second moment
            # is re-ordered along the same axis so that it stays aligned with the re-ordered basis.
            est_eig = torch.diag(o.T @ m @ o)
            sort_idx = torch.argsort(est_eig, descending=True)
            exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)

            power_iter = m @ o[:, sort_idx]

            # Compute QR decomposition
            # We cast to float32 because:
            #  - torch.linalg.qr does not have support for types like bfloat16 as of PyTorch 2.5.1
            #  - the correctness / numerical stability of the Q orthogonality is important for the stability
            #    of the optimizer
            q, _ = torch.linalg.qr(power_iter.to(torch.float32))
            q = q.to(power_iter.dtype)

            matrices.append(q)

        if merge_dims:
            if self.data_format == 'channels_last' and len(orig_shape) == 4:
                exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
            else:
                exp_avg_sq = exp_avg_sq.reshape(orig_shape)

        # The re-ordered second moment replaces the stored one (the state entry is rebound to a new tensor).
        state['exp_avg_sq'] = exp_avg_sq

        return matrices

    @staticmethod
    def init_pre_conditioner(
        grad,
        state,
        precondition_frequency: int = 10,
        shampoo_beta: float = 0.95,
        max_precondition_dim: int = 10000,
        precondition_1d: bool = False,
        merge_dims: bool = False,
    ) -> None:
        r"""Allocate the zero-initialized statistics `state['GG']` and reset `state['Q']`.

        One square matrix is created per (possibly merged) axis of `grad`; axes larger than `max_precondition_dim`,
        and 1D gradients when `precondition_1d` is False, get an empty list instead.
        """
        state['GG'] = []
        if grad.dim() == 1:
            if not precondition_1d or grad.shape[0] > max_precondition_dim:
                state['GG'].append([])
            else:
                state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device, dtype=grad.dtype))
        else:
            if merge_dims:
                grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))

            for sh in grad.shape:
                if sh > max_precondition_dim:
                    state['GG'].append([])
                else:
                    state['GG'].append(torch.zeros(sh, sh, device=grad.device, dtype=grad.dtype))

        # `Q` is computed lazily on the first `update_pre_conditioner` call.
        state['Q'] = None
        state['precondition_frequency'] = precondition_frequency
        state['shampoo_beta'] = shampoo_beta

    def update_pre_conditioner(
        self,
        grad,
        state,
        step: int,
        max_precondition_dim: int = 10000,
        precondition_1d: bool = False,
        merge_dims: bool = False,
    ) -> None:
        r"""Update the statistics `state['GG']` with `grad` and refresh the eigen-bases `state['Q']` when due.

        `Q` is computed with `get_orthogonal_matrix` if it is still None, and refreshed with
        `get_orthogonal_matrix_qr` whenever `step` is a positive multiple of `state['precondition_frequency']`.
        """
        if grad.dim() == 1:
            if precondition_1d and grad.shape[0] <= max_precondition_dim:
                # Moving average of the outer product g g^T.
                state['GG'][0].lerp_(
                    (grad.unsqueeze(1) @ grad.unsqueeze(0)).to(state['GG'][0].dtype),
                    weight=1.0 - state['shampoo_beta'],
                )
        else:
            if merge_dims:
                grad = grad.reshape(merge_small_dims(grad.size(), max_precondition_dim))

            for idx, dim in enumerate(grad.shape):
                if dim <= max_precondition_dim:
                    # Contract `grad` with itself over every axis except `idx`, giving a (dim, dim) matrix.
                    outer_product = torch.tensordot(
                        grad,
                        grad,
                        dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
                    )

                    state['GG'][idx].lerp_(
                        outer_product.to(state['GG'][idx].dtype), weight=1.0 - state['shampoo_beta']
                    )

        if state['Q'] is None:
            state['Q'] = self.get_orthogonal_matrix(state['GG'])

        if step > 0 and step % state['precondition_frequency'] == 0:
            state['Q'] = self.get_orthogonal_matrix_qr(state, max_precondition_dim, merge_dims)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        The first call for a parameter group only initializes its state; no parameter update is applied for that
        group on that call.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                # `init_group` reads `group['step']`, so the counter has to be set first.
                group['step'] = 1
                self.init_group(group)
                continue

            group['step'] += 1

            beta1, beta2 = group['betas']

            step_size: float = group['lr']
            if group['correct_bias']:
                bias_correction1: float = self.debias(beta1, group['step'])
                bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

                step_size *= bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                grad_projected = self.project(
                    grad, state, merge_dims=group['merge_dims'], max_precondition_dim=group['max_precondition_dim']
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # The first moment is tracked in the original space, the second moment in the rotated space.
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                exp_avg_projected = self.project(
                    exp_avg, state, merge_dims=group['merge_dims'], max_precondition_dim=group['max_precondition_dim']
                )

                # Adam-style update computed in the rotated space, then rotated back to the original space.
                norm_grad = self.project(
                    exp_avg_projected / de_nom,
                    state,
                    merge_dims=group['merge_dims'],
                    max_precondition_dim=group['max_precondition_dim'],
                    project_type='backward',
                )

                if group['normalize_gradient']:
                    # Divide by the root-mean-square of the update.
                    norm_grad.div_(torch.mean(norm_grad.square()).sqrt_().add_(group['eps']))

                p.add_(norm_grad, alpha=-step_size)

                # Decoupled weight decay is applied after the parameter update.
                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                # Statistics and eigen-bases are updated last, so this step's projection used the previous basis.
                self.update_pre_conditioner(
                    grad,
                    state,
                    step=group['step'],
                    max_precondition_dim=group['max_precondition_dim'],
                    merge_dims=group['merge_dims'],
                    precondition_1d=group['precondition_1d'],
                )

        return loss

get_orthogonal_matrix_qr(state, max_precondition_dim=10000, merge_dims=False)

Compute the eigen-bases of the pre-conditioner using one round of power iteration.

Source code in pytorch_optimizer/optimizer/soap.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def get_orthogonal_matrix_qr(self, state, max_precondition_dim: int = 10000, merge_dims: bool = False):
    r"""Refresh the eigen-bases of the pre-conditioner with one power-iteration step followed by a QR decomposition.

    For every (`GG`, `Q`) pair stored in `state`, the current basis is re-ordered by its estimated eigenvalues,
    multiplied by the pre-conditioner once and re-orthogonalized. `state['exp_avg_sq']` is permuted along the matching
    axis with the same ordering and written back to `state`.

    :param state: per-parameter state holding `exp_avg_sq`, `GG` and `Q`.
    :param max_precondition_dim: int. forwarded to `merge_small_dims` when `merge_dims` is set.
    :param merge_dims: bool. whether `exp_avg_sq` is handled in its merged-dimension shape.
    :return: list with one new basis per entry of `GG`; an empty list stands in for entries of length 0.
    """
    second_moment = state['exp_avg_sq']
    original_shape = second_moment.shape

    is_channels_last_4d = self.data_format == 'channels_last' and len(original_shape) == 4
    if is_channels_last_4d:
        permuted_shape = second_moment.permute(0, 3, 1, 2).shape

    if merge_dims:
        second_moment = second_moment.reshape(merge_small_dims(second_moment.size(), max_precondition_dim))

    bases = []
    for axis, (gram, basis) in enumerate(zip(state['GG'], state['Q'])):
        if len(gram) == 0:
            # nothing is pre-conditioned along this axis; keep an empty placeholder at the same position
            bases.append([])
            continue

        # order the current basis vectors by their estimated eigenvalues, largest first
        eigen_estimate = torch.diag(basis.T @ gram @ basis)
        order = torch.argsort(eigen_estimate, descending=True)

        # keep the second moment aligned with the re-ordered basis
        second_moment = second_moment.index_select(axis, order)

        iterate = gram @ basis[:, order]

        # QR runs in float32: torch.linalg.qr lacks support for types such as bfloat16 (as of PyTorch 2.5.1), and
        # the orthogonality of Q matters for the numerical stability of the optimizer.
        orthogonal, _ = torch.linalg.qr(iterate.to(torch.float32))
        bases.append(orthogonal.to(iterate.dtype))

    if merge_dims:
        if is_channels_last_4d:
            second_moment = second_moment.reshape(permuted_shape).permute(0, 2, 3, 1)
        else:
            second_moment = second_moment.reshape(original_shape)

    state['exp_avg_sq'] = second_moment

    return bases

SophiaH

Bases: BaseOptimizer

Second-order Clipped Stochastic Optimization.

Requires `loss.backward(create_graph=True)` in order to calculate hessians.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.06
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.96, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
p float

float. clip effective (applied) gradient (p).

0.01
update_period int

int. number of steps after which to apply hessian approximation.

10
num_samples int

int. times to sample z for the approximation of the hessian trace.

1
hessian_distribution HUTCHINSON_G

HUTCHINSON_G. type of distribution to initialize hessian.

'gaussian'
eps float

float. term added to the denominator to improve numerical stability.

1e-12
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sophia.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class SophiaH(BaseOptimizer):
    r"""Second-order Clipped Stochastic Optimization.

    Requires `loss.backward(create_graph=True)` in order to calculate hessians.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param p: float. clip effective (applied) gradient (p).
    :param update_period: int. number of steps after which to apply hessian approximation.
    :param num_samples: int. times to sample `z` for the approximation of the hessian trace.
    :param hessian_distribution: HUTCHINSON_G. type of distribution to initialize hessian.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 6e-2,
        betas: BETAS = (0.96, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        p: float = 1e-2,
        update_period: int = 10,
        num_samples: int = 1,
        hessian_distribution: HUTCHINSON_G = 'gaussian',
        eps: float = 1e-12,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(p, 'p (gradient clip)')
        self.validate_step(update_period, 'update_period')
        self.validate_positive(num_samples, 'num_samples')
        self.validate_options(hessian_distribution, 'hessian_distribution', ['gaussian', 'rademacher'])
        self.validate_non_negative(eps, 'eps')

        # optimizer-wide settings (not stored per parameter group)
        self.update_period = update_period
        self.num_samples = num_samples
        self.distribution = hessian_distribution
        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'p': p,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SophiaH'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the `momentum` and `hessian_moment` buffers for every parameter of `group` that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            # Check for the buffer itself instead of an empty state: `step` hands `self.state` to the hessian helpers
            # before the group is initialized, and they presumably store a 'hessian' entry there, which would make
            # an emptiness check skip the buffer creation.
            if 'momentum' not in state:
                state['momentum'] = torch.zeros_like(grad)
                state['hessian_moment'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :param hessian: Optional[List[torch.Tensor]]. externally computed hessian. when given, it is used instead of
            the periodic Hutchinson estimate and the hessian moment is updated on this step.
        :return: the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Step index of *this* call. The group counter is only incremented further below, so add one here; otherwise
        # the hessian would be estimated one call after the call that folds it into `hessian_moment`.
        step: int = self.param_groups[0].get('step', 0) + 1

        if hessian is not None:
            self.set_hessian(self.param_groups, self.state, hessian)
        elif step % self.update_period == 0:
            self.zero_hessian(self.param_groups, self.state)
            self.compute_hutchinson_hessian(
                param_groups=self.param_groups,
                state=self.state,
                num_samples=self.num_samples,
                distribution=self.distribution,
            )

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                momentum, hessian_moment = state['momentum'], state['hessian_moment']
                momentum.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                # refresh the hessian EMA only on steps where a new hessian was provided or estimated
                if 'hessian' in state and (group['step'] % self.update_period == 0 or hessian is not None):
                    hessian_moment.mul_(beta2).add_(state['hessian'], alpha=1.0 - beta2)

                # pre-conditioned momentum, element-wise clipped to [-p, p]
                update = (momentum / torch.clip(hessian_moment, min=group['eps'])).clamp_(-group['p'], group['p'])

                p.add_(update, alpha=-group['lr'])

        return loss

SPAM

Bases: BaseOptimizer

Spike-Aware Adam with Momentum Reset for Stable LLM Training.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
density float

float. density parameter. only used for 2d parameters (e.g. Linear).

1.0
weight_decay float

float. weight decay (L2 penalty).

0.0
warmup_epoch int

int. number of epochs to warm up. defaults to 50.

50
threshold int

int. threshold for gradient masking. defaults to 5000.

5000
grad_accu_steps int

int. gradient accumulation steps before threshold-based masking applies. defaults to 20.

20
update_proj_gap int

int. update projection gap.

500
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/spam.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
class SPAM(BaseOptimizer):
    r"""Spike-Aware Adam with Momentum Reset for Stable LLM Training.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param density: float. density parameter. only used for 2d parameters (e.g. Linear).
    :param weight_decay: float. weight decay (L2 penalty).
    :param warmup_epoch: int. number of epochs to warm up. defaults to 50.
    :param threshold: int. threshold for gradient masking. defaults to 5000.
    :param grad_accu_steps: int. gradient accumulation steps before threshold-based masking applies. defaults to 20.
    :param update_proj_gap: int. update projection gap.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        density: float = 1.0,
        weight_decay: float = 0.0,
        warmup_epoch: int = 50,
        threshold: int = 5000,
        grad_accu_steps: int = 20,
        update_proj_gap: int = 500,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(warmup_epoch, 'warmup_epoch')
        self.validate_non_negative(density, 'density')
        self.validate_non_negative(threshold, 'threshold')
        self.validate_non_negative(grad_accu_steps, 'grad_accu_steps')
        self.validate_positive(update_proj_gap, 'update_proj_gap')
        self.validate_non_negative(eps, 'eps')

        self.density = density
        self.warmup_epoch = warmup_epoch
        self.threshold = threshold
        self.grad_accu_steps = grad_accu_steps
        self.update_proj_gap = update_proj_gap
        self.maximize = maximize

        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}

        super().__init__(params, defaults)

        self.warmup = CosineDecay(0.99, self.warmup_epoch)

        self.init_masks()

        # global counters live next to the per-parameter state:
        #  - 'total_step' counts every call to `step`
        #  - 'current_step' is the position inside the warmup that restarts after each mask update
        self.state['total_step'] = 0
        self.state['current_step'] = self.warmup_epoch + 1

    @staticmethod
    def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device: torch.device) -> torch.Tensor:
        r"""Create an (m x n) boolean tensor with `density` fraction of True entries.

        :param m: int. number of rows.
        :param n: int. number of columns.
        :param density: float. fraction of True entries. 1.0 means all True.
        :param device: torch.device. device.
        """
        total_elements: int = m * n
        non_zero_count: int = int(density * total_elements)

        tensor = torch.zeros(total_elements, dtype=torch.bool, device=device)

        if non_zero_count > 0:
            tensor[torch.randperm(total_elements, device=device)[:non_zero_count]] = True

        return tensor.view(m, n)

    def update_mask_random(self, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
        r"""Update a random mask.

        Draw a new random mask with the same density and rebuild `exp_avg` / `exp_avg_sq` for it: entries selected by
        both the old and the new mask keep their values, all other entries start from zero.

        :param p: torch.Tensor. parameter to which the mask is applied.
        :param old_mask: torch.Tensor. previous binary mask.
        :return: the new boolean mask.
        """
        new_mask: torch.Tensor = torch.rand_like(p) < self.density

        exp_avg = torch.zeros_like(p[new_mask])
        exp_avg_sq = torch.zeros_like(p[new_mask])

        # positions of the shared entries inside the compacted (masked) buffers of the new / old mask
        intersection_mask = new_mask & old_mask
        new_intersection_indices = intersection_mask[new_mask]
        old_intersection_indices = intersection_mask[old_mask]

        state = self.state[p]
        exp_avg[new_intersection_indices] = state['exp_avg'][old_intersection_indices]
        exp_avg_sq[new_intersection_indices] = state['exp_avg_sq'][old_intersection_indices]

        state['exp_avg'] = exp_avg
        state['exp_avg_sq'] = exp_avg_sq

        return new_mask

    def update_masks(self) -> None:
        r"""Draw a new random mask for every 2D parameter that already has one.

        The new mask is also attached to the parameter as `p.mask`.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if p.dim() == 2 and 'mask' in state:
                    state['mask'] = self.update_mask_random(p, state['mask'])
                    p.mask = state['mask']

    def init_masks(self) -> None:
        r"""Initialize a random mask for every 2D parameter that does not have one yet."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if p.dim() == 2 and 'mask' not in state:
                    state['mask'] = self.initialize_random_rank_boolean_tensor(
                        m=p.shape[0],
                        n=p.shape[1],
                        density=self.density,
                        device=p.device,
                    )

    def __str__(self) -> str:
        return 'SPAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        # per-parameter buffers are created lazily inside `step`
        pass

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        scale_factor: float = 1.0 - self.warmup.get_death_rate(self.state['current_step'])

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if torch.is_complex(p):
                    raise NoComplexParameterError(str(self))

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                if 'mask' in state:
                    # compact 1D copy holding only the selected entries
                    grad = grad[state['mask']]

                # (re)start the moment estimates on first use and on every `update_proj_gap`-th step
                if ('exp_avg' not in state) or (self.state['total_step'] + 1) % self.update_proj_gap == 0:
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                if self.threshold != 0:
                    current_step: int = self.state['total_step'] + 1
                    if current_step >= self.grad_accu_steps and (
                        self.update_proj_gap == 0 or current_step % self.update_proj_gap >= self.grad_accu_steps
                    ):
                        spike_mask = grad.pow(2) > (self.threshold * exp_avg_sq)
                        # Boolean indexing returns a copy, so in-place ops on `grad[spike_mask]` would be lost;
                        # assign the clipped values back explicitly.
                        grad[spike_mask] = grad[spike_mask].sign() * torch.sqrt(
                            exp_avg_sq[spike_mask] * self.threshold
                        )

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if 'mask' in state:
                    # scatter the compact update back into the full parameter shape
                    grad_full = torch.zeros_like(p.grad)
                    grad_full[state['mask']] = exp_avg / de_nom
                    p.add_(grad_full, alpha=-step_size * scale_factor)
                else:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size * scale_factor)

                if 'mask' in state:
                    # `p[mask]` is a copy: decay the selected entries on the copy, then write them back so the decay
                    # actually reaches the parameter.
                    masked_param = p[state['mask']]
                    self.apply_weight_decay(
                        masked_param,
                        grad=None,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=True,
                        fixed_decay=False,
                    )
                    p[state['mask']] = masked_param
                else:
                    self.apply_weight_decay(
                        p,
                        grad=None,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=True,
                        fixed_decay=False,
                    )

        self.state['total_step'] += 1
        self.state['current_step'] += 1

        # one step ahead of the moment reset: draw new masks and restart the warmup
        if (self.state['total_step'] != 0) and (self.state['total_step'] + 1) % self.update_proj_gap == 0:
            self.update_masks()
            self.state['current_step'] = 0
            self.warmup = CosineDecay(0.99, self.warmup_epoch)

        return loss

init_masks()

Initialize a random mask for every 2D parameter that does not have one yet, using the optimizer's density.

Source code in pytorch_optimizer/optimizer/spam.py
167
168
169
170
171
172
173
174
175
176
177
178
def init_masks(self) -> None:
    r"""Attach a random boolean mask to the state of every 2D parameter that does not have one yet.

    The mask has the shape of the parameter and is drawn with the optimizer's `density`.
    """
    for group in self.param_groups:
        for param in group['params']:
            param_state = self.state[param]

            # only matrices are masked, and an existing mask is left untouched
            if param.dim() != 2 or 'mask' in param_state:
                continue

            rows, cols = param.shape[0], param.shape[1]
            param_state['mask'] = self.initialize_random_rank_boolean_tensor(
                m=rows,
                n=cols,
                density=self.density,
                device=param.device,
            )

initialize_random_rank_boolean_tensor(m, n, density, device) staticmethod

Create an (m x n) boolean tensor with density fraction of True entries.

Parameters:

Name Type Description Default
m int

int. number of rows.

required
n int

int. number of columns.

required
density float

float. fraction of True entries. 1.0 means all True.

required
device device

torch.device. device.

required
Source code in pytorch_optimizer/optimizer/spam.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@staticmethod
def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device: torch.device) -> torch.Tensor:
    r"""Build a random (m x n) boolean mask in which a `density` fraction of the entries is True.

    :param m: int. number of rows.
    :param n: int. number of columns.
    :param density: float. fraction of True entries. 1.0 means all True.
    :param device: torch.device. device on which the mask is created.
    :return: boolean tensor of shape (m, n).
    """
    numel: int = m * n
    num_selected: int = int(density * numel)

    flat_mask = torch.zeros(numel, dtype=torch.bool, device=device)

    if num_selected > 0:
        # the first `num_selected` entries of a random permutation are the positions switched on
        chosen = torch.randperm(numel, device=device)[:num_selected]
        flat_mask[chosen] = True

    return flat_mask.view(m, n)

update_mask_random(p, old_mask)

Update a random mask.

Create a new random mask with the same density and carry over the EMA values (exp_avg, exp_avg_sq) for the entries selected by both the old and the new mask; all other entries start from zero.

Parameters:

Name Type Description Default
p Tensor

torch.Tensor. parameter to which the mask is applied.

required
old_mask Tensor

torch.Tensor. previous binary mask.

required
Source code in pytorch_optimizer/optimizer/spam.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def update_mask_random(self, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
    r"""Draw a new random mask for `p` and migrate its moment estimates to it.

    `exp_avg` and `exp_avg_sq` are rebuilt in the compact layout of the new mask: entries selected by both the old
    and the new mask keep their values, every other entry starts from zero.

    :param p: torch.Tensor. parameter to which the mask is applied.
    :param old_mask: torch.Tensor. previous binary mask.
    :return: the new boolean mask.
    """
    fresh_mask: torch.Tensor = torch.rand_like(p) < self.density

    # where the shared entries sit inside the compacted buffers of the new and of the old mask
    shared = fresh_mask & old_mask
    shared_in_new = shared[fresh_mask]
    shared_in_old = shared[old_mask]

    param_state = self.state[p]
    template = p[fresh_mask]

    migrated = {}
    for key in ('exp_avg', 'exp_avg_sq'):
        buffer = torch.zeros_like(template)
        buffer[shared_in_new] = param_state[key][shared_in_old]
        migrated[key] = buffer

    param_state['exp_avg'] = migrated['exp_avg']
    param_state['exp_avg_sq'] = migrated['exp_avg_sq']

    return fresh_mask

update_masks()

Update the mask of every 2D parameter that already has one.

The new mask is selected randomly, and the moment estimates of the entries it shares with the old mask are carried over.

Source code in pytorch_optimizer/optimizer/spam.py
155
156
157
158
159
160
161
162
163
164
165
def update_masks(self) -> None:
    r"""Draw a new random mask for every 2D parameter that already has one.

    The new mask replaces the one stored in the parameter's state and is also attached to the parameter as `p.mask`.
    """
    for group in self.param_groups:
        for param in group['params']:
            param_state = self.state[param]

            # parameters that are not matrices, or that were never masked, are skipped
            if param.dim() != 2 or 'mask' not in param_state:
                continue

            refreshed = self.update_mask_random(param, param_state['mask'])
            param_state['mask'] = refreshed
            param.mask = param_state['mask']

StableSPAM

Bases: BaseOptimizer

How to Train in 4-Bit More Stably than 16-Bit Adam.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
gamma1 float

float. gamma1 parameter.

0.7
gamma2 float

float. gamma2 parameter.

0.9
theta float

float. theta parameter.

0.999
t_max Optional[int]

Optional[int]. total number of steps.

None
eta_min float

float. eta_min of CosineDecay.

0.5
weight_decay float

float. weight decay (L2 penalty).

0.0
update_proj_gap int

int. update projection gap.

1000
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/spam.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
class StableSPAM(BaseOptimizer):
    r"""How to Train in 4-Bit More Stably than 16-Bit Adam.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param gamma1: float. gamma1 parameter. -1.0 means "use betas[0]".
    :param gamma2: float. gamma2 parameter.
    :param theta: float. theta parameter.
    :param t_max: Optional[int]. total number of steps.
    :param eta_min: float. eta_min of CosineDecay.
    :param weight_decay: float. weight decay (L2 penalty).
    :param update_proj_gap: int. update projection gap.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        gamma1: float = 0.7,
        gamma2: float = 0.9,
        theta: float = 0.999,
        t_max: Optional[int] = None,
        eta_min: float = 0.5,
        weight_decay: float = 0.0,
        update_proj_gap: int = 1000,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_positive(update_proj_gap, 'update_proj_gap')
        self.validate_non_negative(eps, 'eps')

        self.gamma1: float = betas[0] if gamma1 == -1.0 else gamma1
        self.gamma2: float = gamma2
        self.theta: float = theta
        self.t_max = t_max
        self.update_proj_gap = update_proj_gap
        # without `t_max` no decay schedule is used and `step` falls back to a constant scale of 1.0
        self.warmup = CosineDecay(1.0, t_max, eta_min=eta_min) if t_max is not None else None
        self.maximize = maximize

        # number of `step` calls so far; drives the decay schedule and the periodic moment reset
        self.total_step: int = 0

        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'StableSPAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        r"""Create the state buffers for every parameter of `group` that has a gradient.

        :param group: GROUP. parameter group to initialize.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if 'exp_avg' not in state:
                state['exp_avg'] = torch.zeros_like(grad)
                state['exp_avg_sq'] = torch.zeros_like(grad)
                # running statistics of the gradient norm, its square, and the largest absolute gradient entry
                state['m_norm_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)
                state['v_norm_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)
                state['m_max_t'] = torch.zeros(1, device=grad.device, dtype=grad.dtype)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        r"""Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.total_step += 1

        scale: float = self.warmup.get_death_rate(self.total_step) if self.warmup is not None else 1.0

        # The moment estimates are restarted every `update_proj_gap` steps. Decide this once per call so that the
        # bias corrections and every parameter of a group see the same step count.
        reset_moments: bool = self.update_proj_gap > 0 and self.total_step % self.update_proj_gap == 0

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            beta1, beta2 = group['betas']
            beta1 *= scale

            # The running max / norm statistics are not restarted on this step, so they are debiased with the step
            # count from before the reset.
            norm_step: int = group['step']
            theta_t: float = 1.0 - self.theta ** norm_step

            if reset_moments:
                group['step'] = 1

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2: float = self.debias(beta2, group['step'])
            bias_correction2_sq: float = math.sqrt(bias_correction2)

            step_size: float = group['lr'] / bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=True,
                    fixed_decay=False,
                )

                max_grad = torch.max(grad.abs())

                exp_avg, exp_avg_sq, m_max_t = state['exp_avg'], state['exp_avg_sq'], state['m_max_t']

                m_max_t.lerp_(max_grad, weight=1.0 - self.theta)

                m_max_hat = m_max_t / theta_t

                mask = grad.abs() > m_max_hat
                if mask.sum() > 0:
                    # Boolean indexing returns a copy, so in-place ops on `grad[mask]` would be lost; assign the
                    # rescaled entries back explicitly.
                    grad[mask] = grad[mask] / max_grad * m_max_hat

                grad_norm = torch.norm(grad)

                m_norm_t, v_norm_t = state['m_norm_t'], state['v_norm_t']
                m_norm_t.lerp_(grad_norm, weight=1.0 - self.gamma1 * scale)
                v_norm_t.lerp_(grad_norm.pow(2), weight=1.0 - self.gamma2)

                m_norm_hat = m_norm_t / (1.0 - (self.gamma1 * scale) ** norm_step)
                v_norm_hat = v_norm_t / (1.0 - self.gamma2 ** norm_step)

                c_norm_t = m_norm_hat.div_(v_norm_hat.sqrt_().add_(group['eps']))

                # replace the gradient norm by its smoothed, normalized estimate
                grad.div_(grad_norm).mul_(c_norm_t)

                if reset_moments:
                    # Zero the buffers in place: rebinding `state[...]` here would leave the local references
                    # pointing at the old tensors, so this step would still use (and update) the stale moments.
                    exp_avg.zero_()
                    exp_avg_sq.zero_()

                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss

SRMM

Bases: BaseOptimizer

Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
beta float

float. adaptivity weight.

0.5
memory_length Optional[int]

Optional[int]. internal memory length for moving average. None for no refreshing.

100
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/srmm.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class SRMM(BaseOptimizer):
    """Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates.

    On every step the gradient and the parameter are blended into per-parameter moving averages with weight
    ``w_t = (k + 1) ** -beta``; the averaged gradient, scaled by ``lr``, is then subtracted from the averaged
    parameter and the result is copied back into the parameter. ``k`` is the group step count, wrapped modulo
    ``memory_length`` when a memory length is given and left unwrapped when ``memory_length`` is None.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. adaptivity weight.
    :param memory_length: Optional[int]. internal memory length for moving average. None for no refreshing.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 0.01,
        beta: float = 0.5,
        memory_length: Optional[int] = 100,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        if memory_length is not None:
            # a zero length would divide by zero in `step`, and a negative one would wrap the step count to a
            # non-positive base for the fractional power below.
            self.validate_positive(memory_length, 'memory_length')

        self.maximize = maximize

        defaults: DEFAULTS = {'lr': lr, 'beta': beta, 'memory_length': memory_length}

        super().__init__(params, defaults)

        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'SRMM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create zero-filled moving-average buffers for every parameter in `group` that has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(p):
                raise NoComplexParameterError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['mov_avg_grad'] = torch.zeros_like(grad)
                state['mov_avg_param'] = torch.zeros_like(grad)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            memory_length: Optional[int] = group['memory_length']

            # Without a memory length the step count is used as-is, so the averaging weight keeps shrinking.
            # (Wrapping modulo 1 here would pin the count to 0 and the weight to 1, i.e. no averaging at all.)
            # With a memory length the count wraps around, periodically resetting the weight.
            effective_step: int = group['step'] if memory_length is None else group['step'] % memory_length

            w_t: float = (effective_step + 1) ** -group['beta']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                mov_avg_grad, mov_avg_param = state['mov_avg_grad'], state['mov_avg_param']

                # convex combinations of the previous averages and the current gradient / parameter.
                mov_avg_grad.mul_(1.0 - w_t).add_(grad, alpha=w_t)
                mov_avg_param.mul_(1.0 - w_t).add_(p, alpha=w_t)

                # the descent step is taken on the averaged parameter, which then becomes the new parameter.
                mov_avg_param.add_(mov_avg_grad, alpha=-group['lr'])

                p.copy_(mov_avg_param)

        return loss

SWATS

Bases: BaseOptimizer

Improving Generalization Performance by Switching from Adam to SGD.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
ams_bound bool

bool. whether to use the ams_bound variant of this algorithm from the paper.

False
nesterov bool

bool. enables Nesterov momentum.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-06
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/swats.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class SWATS(BaseOptimizer):
    r"""Improving Generalization Performance by Switching from Adam to SGD.

    Every parameter group starts in the ``'adam'`` phase. Once the bias-corrected running average of the
    projected step scaling is positive and close to its latest value, the group switches to the ``'sgd'``
    phase and that average becomes the group's learning rate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the ams_bound variant of this algorithm from the paper.
    :param nesterov: bool. enables Nesterov momentum.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        nesterov: bool = False,
        eps: float = 1e-6,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        # every group begins in the Adam phase; `step` flips 'phase' to 'sgd' later on.
        defaults: DEFAULTS = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            ams_bound=ams_bound,
            nesterov=nesterov,
            phase='adam',
            eps=eps,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SWATS'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Allocate the Adam-phase buffers for every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        :raises NoComplexParameterError: if a parameter is complex.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            if torch.is_complex(param):
                raise NoComplexParameterError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                # buffers already exist for this parameter.
                continue

            param_state['exp_avg'] = torch.zeros_like(param)
            param_state['exp_avg_sq'] = torch.zeros_like(param)
            # single-element running average of the projected step scaling.
            param_state['exp_avg2'] = torch.zeros((1,), dtype=param.dtype, device=param.device)

            if group['ams_bound']:
                param_state['max_exp_avg_sq'] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']

            debias1: float = self.debias(beta1, group['step'])
            debias2: float = self.debias(beta2, group['step'])

            adam_step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False),
                step_size=group['lr'] * math.sqrt(debias2),
                bias_correction1=debias1,
            )

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]

                # group['lr'] is read on every iteration because a phase switch below may overwrite it.
                self.apply_weight_decay(
                    p=param,
                    grad=gradient,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['phase'] == 'sgd':
                    # SGD phase: momentum update with the learning rate stored at the switch.
                    if 'momentum_buffer' not in param_state:
                        param_state['momentum_buffer'] = torch.zeros_like(gradient)

                    momentum_buf = param_state['momentum_buffer']
                    momentum_buf.mul_(beta1).add_(gradient)

                    sgd_update = momentum_buf.clone().mul_(1.0 - beta1)
                    if group['nesterov']:
                        sgd_update.add_(momentum_buf, alpha=beta1)

                    param.add_(sgd_update, alpha=-group['lr'])

                    continue

                # Adam phase.
                first_moment = param_state['exp_avg']
                second_moment = param_state['exp_avg_sq']

                first_moment.mul_(beta1).add_(gradient, alpha=1.0 - beta1)
                second_moment.mul_(beta2).addcmul_(gradient, gradient, value=1.0 - beta2)

                denominator = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=second_moment,
                    max_exp_avg_sq=param_state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                adam_delta = first_moment.clone().div_(denominator).mul_(-adam_step_size)

                param.add_(adam_delta)

                flat_delta = adam_delta.view(-1)
                delta_dot_grad = flat_delta.dot(gradient.view(-1))

                if delta_dot_grad == 0:
                    # the scaling below would divide by zero.
                    continue

                scaling = flat_delta.dot(flat_delta).div_(-delta_dot_grad)

                scaling_avg = param_state['exp_avg2']
                scaling_avg.mul_(beta2).add_(scaling, alpha=1.0 - beta2)

                debiased_scaling = scaling_avg / debias2

                # switch criterion: past the first step, positive average, and average close to current value.
                if group['step'] <= 1:
                    continue

                if debiased_scaling > 0.0 and debiased_scaling.allclose(scaling, rtol=group['eps']):
                    group['phase'] = 'sgd'
                    group['lr'] = debiased_scaling.item()

        return loss

TAM

Bases: BaseOptimizer

Torque-Aware Momentum.

:param decay_rate: float. smoothing decay rate.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. coefficients used for computing running averages of gradient.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/tam.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class TAM(BaseOptimizer):
    r"""Torque-Aware Momentum.

    The gradient is damped by ``(1 + s) / 2 + eps`` before it is accumulated into the momentum buffer, where
    ``s`` is a running average (rate ``decay_rate``) of the product of the normalized momentum buffer and the
    normalized gradient.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. coefficients used for computing running averages of gradient.
    :param decay_rate: float. smoothing decay rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.9,
        decay_rate: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(decay_rate, 'decay_rate', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'decay_rate': decay_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'TAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create the state buffers for every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['s'] = torch.zeros_like(grad)
                # the momentum buffer starts as a copy of the first gradient rather than zeros.
                state['momentum_buffer'] = grad.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' not in group:
                self.init_group(group)
                group['step'] = 1
            else:
                group['step'] += 1

            momentum: float = group['momentum']
            decay_rate: float = group['decay_rate']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                # Weight decay is applied before the gradient is consumed, the same ordering AdaTAM uses.
                # A coupled decay that is folded into `grad` would otherwise arrive after the momentum buffer
                # has been updated and never influence this step. Nothing between here and the final
                # `p.add_` reads `p`, so a decay that only rescales `p` behaves as before.
                # NOTE(review): assumes apply_weight_decay either rescales `p` or adds to `grad` in place;
                # confirm against BaseOptimizer.
                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s, momentum_buffer = state['s'], state['momentum_buffer']

                # product of the L2-normalized (along dim 0) momentum buffer and gradient.
                corr = normalize(momentum_buffer, p=2.0, dim=0).mul_(normalize(grad, p=2.0, dim=0))
                s.mul_(decay_rate).add_(corr, alpha=1.0 - decay_rate)

                # damped gradient that is accumulated into the momentum buffer.
                d = ((1.0 + s) / 2.0).add_(group['eps']).mul_(grad)

                momentum_buffer.mul_(momentum).add_(d)

                p.add_(momentum_buffer, alpha=-group['lr'])

        return loss

AdaTAM

Bases: BaseOptimizer

Adaptive Torque-Aware Momentum.

:param decay_rate: float. smoothing decay rate.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and its square.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/tam.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
class AdaTAM(BaseOptimizer):
    r"""Adaptive Torque-Aware Momentum.

    Uses the same damping as TAM, ``(1 + s) / 2 + eps`` applied to the gradient, but accumulates the damped
    gradient into ``exp_avg`` and divides the update by the square root of a running average of the squared
    gradient.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param decay_rate: float. smoothing decay rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        decay_rate: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(decay_rate, 'decay_rate', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = dict(
            lr=lr,
            betas=betas,
            decay_rate=decay_rate,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
            eps=eps,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaTAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create zero-filled state buffers for every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) != 0:
                # buffers already exist for this parameter.
                continue

            for buffer_name in ('s', 'exp_avg', 'exp_avg_sq'):
                param_state[buffer_name] = torch.zeros_like(gradient)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta1, beta2 = group['betas']
            decay_rate: float = group['decay_rate']

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]

                self.apply_weight_decay(
                    param,
                    gradient,
                    group['lr'],
                    group['weight_decay'],
                    group['weight_decouple'],
                    group['fixed_decay'],
                )

                smoothed_corr = param_state['s']
                first_moment = param_state['exp_avg']
                second_moment = param_state['exp_avg_sq']

                # product of the L2-normalized (along dim 0) first moment and gradient.
                unit_moment = normalize(first_moment, p=2.0, dim=0)
                unit_grad = normalize(gradient, p=2.0, dim=0)
                smoothed_corr.mul_(decay_rate).add_(unit_moment.mul_(unit_grad), alpha=1.0 - decay_rate)

                # damping factor (1 + s) / 2 + eps, applied to the gradient.
                damped_grad = ((1.0 + smoothed_corr) / 2.0).add_(group['eps'])
                damped_grad.mul_(gradient)

                first_moment.mul_(beta1).add_(damped_grad)
                second_moment.mul_(beta2).addcmul_(gradient, gradient, value=1.0 - beta2)

                denominator = second_moment.sqrt().add_(group['eps'])
                param.addcdiv_(first_moment, denominator, value=-group['lr'])

        return loss

Tiger

Bases: BaseOptimizer

A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
beta float

float. coefficient used for computing the running average of the gradient.

0.965
weight_decay float

float. weight decay (L2 penalty).

0.01
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/tiger.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
class Tiger(BaseOptimizer):
    r"""A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious.

    Keeps one running average of the gradient per parameter and moves each parameter by ``lr`` times the
    sign of that average.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. coefficient used for computing the running average of the gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        beta: float = 0.965,
        weight_decay: float = 0.01,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.maximize = maximize

        defaults: DEFAULTS = dict(
            lr=lr,
            beta=beta,
            weight_decay=weight_decay,
            weight_decouple=weight_decouple,
            fixed_decay=fixed_decay,
        )

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Tiger'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create a zero-filled `exp_avg` buffer for every parameter in `group` that currently has a gradient.

        :param group: GROUP. parameter group whose parameters are initialized.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            param_state = self.state[param]
            if len(param_state) == 0:
                param_state['exp_avg'] = torch.zeros_like(gradient)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the value returned by `closure`, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            beta = group['beta']

            for param in group['params']:
                gradient = param.grad
                if gradient is None:
                    continue

                self.maximize_gradient(gradient, maximize=self.maximize)

                param_state = self.state[param]

                self.apply_weight_decay(
                    p=param,
                    grad=gradient,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                grad_avg = param_state['exp_avg']
                grad_avg.mul_(beta).add_(gradient, alpha=1.0 - beta)

                # complex averages go through torch.sgn, real ones through torch.sign.
                if torch.is_complex(grad_avg):
                    direction = torch.sgn(grad_avg)
                else:
                    direction = torch.sign(grad_avg)

                param.add_(direction, alpha=-group['lr'])

        return loss

TRAC

Bases: BaseOptimizer

A Parameter-Free Optimizer for Lifelong Reinforcement Learning.

Example:

Here's an example::

    model = YourModel()
    optimizer = TRAC(AdamW(model.parameters()))

    for input, output in data:
        optimizer.zero_grad()

        loss = loss_fn(model(input), output)
        loss.backward()

        optimizer.step()

Parameters:

Name Type Description Default
optimizer OPTIMIZER_INSTANCE_OR_CLASS

OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.

required
betas List[float]

List[float]. list of beta values.

(0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999)
num_coefs int

int. the number of polynomial coefficients to use in the approximation.

128
s_prev float

float. initial scale value.

1e-08
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/trac.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
class TRAC(BaseOptimizer):
    r"""A Parameter-Free Optimizer for Lifelong Reinforcement Learning.

    TRAC wraps a base optimizer. Each :meth:`step` first lets the base optimizer update the parameters and then
    calls :meth:`trac_step`, which re-positions every parameter relative to a stored reference copy using a single
    scalar ``s`` that is re-estimated on every step. `param_groups`, `state`, `state_dict`, `load_state_dict` and
    `zero_grad` are all delegated to the wrapped optimizer.

    Example:
    -------
        Here's an example::

            model = YourModel()
            optimizer = TRAC(AdamW(model.parameters()))

            for input, output in data:
                optimizer.zero_grad()

                loss = loss_fn(model(input), output)
                loss.backward()

                optimizer.step()

    :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. base optimizer.
    :param betas: List[float]. list of beta values. one `variance` and one `sigma` entry is kept per beta.
    :param num_coefs: int. the number of polynomial coefficients to use in the approximation.
    :param s_prev: float. initial scale value.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param kwargs: Dict. forwarded to `load_optimizer` together with `optimizer`.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER_INSTANCE_OR_CLASS,
        betas: List[float] = (0.9, 0.99, 0.999, 0.9999, 0.99999, 0.999999),
        num_coefs: int = 128,
        s_prev: float = 1e-8,
        eps: float = 1e-8,
        **kwargs,
    ):
        self.validate_positive(num_coefs, 'num_coefs')
        self.validate_non_negative(s_prev, 's_prev')
        self.validate_non_negative(eps, 'eps')

        # `super().__init__` is never called in this constructor, so the hook registries are created by hand.
        # NOTE(review): presumably torch's `Optimizer.step` hook machinery expects these attributes -- confirm.
        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

        self.betas = betas
        self.s_prev = s_prev
        self.eps = eps

        # `self.erf` must exist before `f_term` is computed because `erf_imag` calls it.
        self.erf: nn.Module = ERF1994(num_coefs=num_coefs)
        # constant factor applied to every erf term in `trac_step`: s_prev / erf_imag(1 / sqrt(2)).
        self.f_term: torch.Tensor = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))

        self.defaults: DEFAULTS = self.optimizer.defaults

    def __str__(self) -> str:
        return 'TRAC'

    @property
    def param_groups(self):
        """Return the parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    @property
    def state(self) -> STATE:
        """Return the state of the wrapped optimizer; TRAC keeps its own data under the ``'trac'`` key of it."""
        return self.optimizer.state

    def state_dict(self) -> STATE:
        """Return the ``state_dict`` of the wrapped optimizer."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: STATE) -> None:
        """Load ``state_dict`` into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Store a clone of each parameter's snapshot as its reference point.

        :param group: GROUP. parameter group whose parameters are initialised.
        :param kwargs: must contain ``updates``, a dict mapping each parameter to its snapshot tensor.
        """
        updates = kwargs.get('updates')
        for p in group['params']:
            # read back as `theta_ref` in `trac_step`.
            self.state['trac'][p] = updates[p].clone()

    @torch.no_grad()
    def zero_grad(self, set_to_none: bool = True) -> None:
        """Clear the gradients through the wrapped optimizer."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    @torch.no_grad()
    def erf_imag(self, x: torch.Tensor) -> torch.Tensor:
        """Return the imaginary part of ``self.erf`` evaluated at the purely imaginary input ``i * x``.

        :param x: torch.Tensor. input; a non-floating-point tensor is replaced by its real part cast to float32.
        """
        if not torch.is_floating_point(x):
            x = x.real.to(torch.float32)

        # complex tensor with zero real part and `x` as imaginary part.
        ix = torch.complex(torch.zeros_like(x), x)

        return self.erf(ix).imag

    @torch.no_grad()
    def backup_params_and_grads(self) -> Tuple[Dict, Dict]:
        """Snapshot every parameter and its gradient.

        :return: ``(updates, grads)``; both are keyed by parameter. ``updates[p]`` is a clone of ``p`` and
            ``grads[p]`` is a clone of ``p.grad`` or ``None`` when the parameter has no gradient.
        """
        updates, grads = {}, {}

        for group in self.param_groups:
            for p in group['params']:
                updates[p] = p.clone()
                grads[p] = p.grad.clone() if p.grad is not None else None

        return updates, grads

    @torch.no_grad()
    def trac_step(self, updates: Dict, grads: Dict) -> None:
        """Re-estimate the scale ``s`` and rewrite each parameter as ``theta_ref + delta * max(s, 0)``.

        Parameters whose entry in ``grads`` is ``None`` are left untouched. The tensors stored in ``updates`` are
        modified in place.

        :param updates: Dict. parameter snapshots from `backup_params_and_grads`.
        :param grads: Dict. gradient snapshots from `backup_params_and_grads`.
        """
        self.state['trac']['step'] += 1

        deltas = {}

        device = self.param_groups[0]['params'][0].device

        s = self.state['trac']['s']
        # accumulates sum over parameters of <delta, grad>.
        h = torch.zeros((1,), device=device)
        for group in self.param_groups:
            for p in group['params']:
                if grads[p] is None:
                    continue

                theta_ref = self.state['trac'][p]
                update = updates[p]

                # displacement of the snapshot from the reference point, divided by (s + eps).
                deltas[p] = (update - theta_ref) / s.add(self.eps)
                # in place: update <- p - update. this mutates the tensor held in `updates`.
                update.neg_().add_(p)

                grad, delta = grads[p], deltas[p]

                product = torch.dot(delta.flatten(), grad.flatten())
                h.add_(product)

                # `delta` aliases `deltas[p]`, so the stored delta is changed as well.
                delta.add_(update)

                # reset the parameter to its reference point; the scaled delta is added back below.
                p.copy_(theta_ref)

        betas = self.state['trac']['betas']
        variance = self.state['trac']['variance']
        sigma = self.state['trac']['sigma']

        # per-beta accumulators: variance <- variance * beta^2 + h^2, sigma <- sigma * beta - h.
        variance.mul_(betas.pow(2)).add_(h.pow(2))
        sigma.mul_(betas).sub_(h)

        # s <- sum over betas of f_term * erf_imag(sigma / (sqrt(2 * variance) + eps)).
        term = self.erf_imag(sigma / (2.0 * variance).sqrt_().add_(self.eps)).mul_(self.f_term)
        s.copy_(torch.sum(term))

        # a negative scale is replaced by 0.0.
        scale = max(s, 0.0)

        for group in self.param_groups:
            for p in group['params']:
                if grads[p] is None:
                    continue

                p.add_(deltas[p] * scale)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Run the wrapped optimizer, then apply the TRAC rescaling.

        :param closure: CLOSURE. passed unchanged to the wrapped optimizer's ``step``.
        :return: whatever the wrapped optimizer's ``step`` returned.
        """
        # TODO: backup is first to get the delta of param and grad, but it does not work.
        with torch.enable_grad():
            loss = self.optimizer.step(closure)

        # snapshots are taken after the base optimizer has already updated the parameters.
        updates, grads = self.backup_params_and_grads()

        # first call: create the TRAC state on the device of the first parameter and record the reference points.
        if 'trac' not in self.state:
            device = self.param_groups[0]['params'][0].device

            self.state['trac'] = {
                'betas': torch.tensor(self.betas, device=device),
                's': torch.zeros(1, device=device),
                'variance': torch.zeros(len(self.betas), device=device),
                'sigma': torch.full((len(self.betas),), 1e-8, device=device),
                'step': 0,
            }

            for group in self.param_groups:
                self.init_group(group, updates=updates)

        self.trac_step(updates, grads)

        return loss

VSGD

Bases: BaseOptimizer

Variational Stochastic Gradient Descent for Deep Neural Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.1
ghattg float

float. prior variance ratio between ghat and g, Var(ghat_t-g_t)/Var(g_t-g_{t-1}).

30.0
ps float

float. prior strength.

1e-08
tau1 float

float. remember rate for the gamma parameters of g.

0.81
tau2 float

float. remember rate for the gamma parameter of ghat.

0.9
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
eps float

float. term added to the denominator to improve numerical stability.

1e-08
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/sgd.py
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
class VSGD(BaseOptimizer):
    r"""Variational Stochastic Gradient Descent for Deep Neural Networks.

    Three tensors are kept per parameter: ``mug``, which gives the update direction, and the accumulators ``bg``
    and ``bhg``, from which the mixing weights of the following step are derived.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param ghattg: float. prior variance ratio between ghat and g, Var(ghat_t-g_t)/Var(g_t-g_{t-1}).
    :param ps: float. prior strength.
    :param tau1: float. remember rate for the gamma parameters of g.
    :param tau2: float. remember rate for the gamma parameter of ghat.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        ghattg: float = 30.0,
        ps: float = 1e-8,
        tau1: float = 0.81,
        tau2: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        eps: float = 1e-8,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)

        non_negative_checks = (
            (ghattg, 'ghattg'),
            (ps, 'ps'),
            (tau1, 'tau1'),
            (tau2, 'tau2'),
            (weight_decay, 'weight_decay'),
            (eps, 'eps'),
        )
        for value, name in non_negative_checks:
            self.validate_non_negative(value, name)

        self.maximize = maximize

        # the prior hyper-parameters are derived once from `ps` and `ghattg` and stored in the group defaults.
        defaults: DEFAULTS = {
            'lr': lr,
            'tau1': tau1,
            'tau2': tau2,
            'pa2': 2.0 * ps + 1.0 + 1e-4,
            'pbg2': 2.0 * ps,
            'pbhg2': 2.0 * ghattg * ps,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'VSGD'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create zero-filled ``mug`` / ``bg`` / ``bhg`` state for every parameter of ``group`` that has a gradient.

        :param group: GROUP. parameter group to initialise.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for param in group['params']:
            gradient = param.grad
            if gradient is None:
                continue

            if gradient.is_sparse:
                raise NoSparseGradientError(str(self))

            slot = self.state[param]
            if slot:
                continue

            for key in ('mug', 'bg', 'bhg'):
                slot[key] = torch.zeros_like(param)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                self.init_group(group)
                group['step'] = 1

            first_step = group['step'] == 1

            pa2 = group['pa2']
            pbg2 = group['pbg2']
            pbhg2 = group['pbhg2']

            rho1: float = math.pow(group['step'], -group['tau1'])
            rho2: float = math.pow(group['step'], -group['tau2'])

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[param]

                self.apply_weight_decay(
                    param,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                acc_g = state['bg']
                acc_hg = state['bhg']

                # the very first step uses the prior values; afterwards the accumulators are used.
                if first_step:
                    sg = pbg2 / (pa2 - 1.0)
                    shg = pbhg2 / (pa2 - 1.0)
                else:
                    sg = acc_g / pa2
                    shg = acc_hg / pa2

                weight_sum = sg + shg

                mean = state['mug']
                mean_before = mean.clone()

                # mean <- (mean * shg + grad * sg) / (sg + shg), in place.
                mean.mul_(shg)
                mean.add_(grad * sg)
                mean.div_(weight_sum)

                sigg = (sg * shg) / weight_sum
                mean_sq = mean.pow(2).add_(sigg)

                target_g = pbg2 + mean_sq - 2.0 * mean * mean_before + mean_before.pow(2)
                target_hg = pbhg2 + mean_sq - 2.0 * grad * mean + grad.pow(2)

                # step-dependent interpolation of the accumulators towards their new targets.
                acc_g.mul_(1.0 - rho1)
                acc_g.add_(target_g, alpha=rho1)
                acc_hg.mul_(1.0 - rho2)
                acc_hg.add_(target_hg, alpha=rho2)

                denominator = mean_sq.sqrt().add_(group['eps'])
                param.add_(group['lr'] / denominator * mean, alpha=-1.0)

        return loss

WSAM

Bases: BaseOptimizer

Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term.

Parameters:

Name Type Description Default
model Union[Module, DistributedDataParallel]

Union[torch.nn.Module, DistributedDataParallel]. the model instance. A DDP model is recommended so that model.no_sync works.

required
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer OPTIMIZER

Optimizer. base optimizer.

required
rho float

float. size of the neighborhood for computing the max loss.

0.05
gamma float

float. weighting factor; the sharpness term is weighted by gamma / (1 - gamma). values between 0.8 and 0.95 work best.

0.9
adaptive bool

bool. element-wise adaptive SAM.

False
decouple bool

bool. whether to perform a decoupled sharpness regularization.

True
max_norm Optional[float]

Optional[float]. max norm of the gradients.

None
eps float

float. term added to the denominator of WSAM to improve numerical stability.

1e-12
kwargs

Dict. parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
class WSAM(BaseOptimizer):
    r"""Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term.

    A step needs two forward/backward passes: :meth:`first_step` perturbs the weights along the gradient and
    remembers the unperturbed gradient, :meth:`second_step` undoes the perturbation and lets the base optimizer
    update the weights. :meth:`step` drives both passes through a closure.

    :param model: Union[torch.nn.Module, DistributedDataParallel]. the model instance. a DDP model is recommended
        so that `model.no_sync` works.
    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: Optimizer. base optimizer.
    :param rho: float. size of the neighborhood for computing the max loss.
    :param gamma: float. weighting factor; the sharpness term is weighted by gamma / (1 - gamma). values between
        0.8 and 0.95 work best.
    :param adaptive: bool. element-wise adaptive SAM.
    :param decouple: bool. whether to perform a decoupled sharpness regularization.
    :param max_norm: Optional[float]. max norm of the gradients.
    :param eps: float. term added to the denominator of WSAM to improve numerical stability.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        model: Union[nn.Module, DistributedDataParallel],
        params: PARAMETERS,
        base_optimizer: OPTIMIZER,
        rho: float = 0.05,
        gamma: float = 0.9,
        adaptive: bool = False,
        decouple: bool = True,
        max_norm: Optional[float] = None,
        eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')

        self.model = model
        self.decouple = decouple
        self.max_norm = max_norm

        alpha: float = gamma / (1.0 - gamma)

        # optimizer kwargs are merged last, so they win over the WSAM entries on a key clash.
        defaults: DEFAULTS = {'rho': rho, 'alpha': alpha, 'adaptive': adaptive, 'sam_eps': eps}
        defaults.update(kwargs)

        super().__init__(params, defaults)

        # the base optimizer is built on the same groups, and both objects share its `param_groups` list.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'WSAM'

    def init_group(self, group: GROUP, **kwargs) -> None:
        pass

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        """Move the weights to ``w + e_w`` and store the gradient taken at ``w``.

        :param zero_grad: bool. clear the gradients afterwards.
        """
        device = self.param_groups[0]['params'][0].device

        grad_norm = get_global_gradient_norm(self.param_groups, device)

        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + group['sam_eps'])
            adaptive = group['adaptive']

            for param in group['params']:
                if param.grad is None:
                    continue

                weight = torch.pow(param, 2) if adaptive else 1.0
                perturbation = weight * param.grad * scale.to(param)

                param.add_(perturbation)

                self.state[param]['e_w'] = perturbation

                if is_initialized():  # pragma: no cover
                    all_reduce(param.grad, op=ReduceOp.AVG)

        # second pass: remember the gradient of every parameter that has one.
        with_grad = [param for group in self.param_groups for param in group['params'] if param.grad is not None]
        for param in with_grad:
            self.state[param]['grad'] = param.grad.clone()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        """Restore the weights, combine the two gradients and run the base optimizer.

        :param zero_grad: bool. clear the gradients afterwards.
        """
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue

                if is_initialized():  # pragma: no cover
                    all_reduce(param.grad, ReduceOp.AVG)

                # back from `w + e_w` to `w`.
                param.add_(self.state[param]['e_w'], alpha=-1.0)

        if self.max_norm is not None:
            clip_grad_norm_(self.model.parameters(), self.max_norm)

        for group in self.param_groups:
            alpha = group['alpha']

            for param in group['params']:
                if param.grad is None:
                    continue

                slot = self.state[param]
                stored_grad = slot['grad']

                if self.decouple:
                    # keep the gradient difference for the separate update below and hand the stored
                    # gradient to the base optimizer.
                    slot['sharpness'] = param.grad.clone() - stored_grad
                    param.grad.mul_(0.0)
                    param.grad.add_(stored_grad, alpha=1.0)
                else:
                    param.grad.mul_(alpha)
                    param.grad.add_(stored_grad, alpha=1.0 - alpha)

        self.base_optimizer.step()

        if self.decouple:
            for group in self.param_groups:
                step_scale = -group['lr'] * group['alpha']

                for param in group['params']:
                    if param.grad is None:
                        continue

                    param.add_(self.state[param]['sharpness'], alpha=step_scale)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        """Run a full WSAM update; the closure is evaluated twice.

        :param closure: CLOSURE. callable performing a forward and backward pass and returning the loss.
        :return: the loss of the first closure call.
        :raises NoClosureError: if ``closure`` is None.
        """
        if closure is None:
            raise NoClosureError(str(self))

        # this method runs under `no_grad`, so gradients are re-enabled for the closure.
        grad_closure = torch.enable_grad()(closure)

        enable_running_stats(self.model)
        loss = grad_closure()

        self.first_step(zero_grad=True)

        disable_running_stats(self.model)
        grad_closure()

        self.second_step()

        return loss

    def load_state_dict(self, state_dict: Dict):
        """Load ``state_dict`` and point the base optimizer at the resulting parameter groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

Yogi

Bases: BaseOptimizer

Adaptive Methods for Nonconvex Optimization.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
betas BETAS

BETAS. coefficients used for computing running averages of the gradient and of the squared gradient.

(0.9, 0.999)
initial_accumulator float

float. initial values for first and second moments.

1e-06
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

0.001
maximize bool

bool. maximize the objective with respect to the params, instead of minimizing.

False
Source code in pytorch_optimizer/optimizer/yogi.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class Yogi(BaseOptimizer):
    r"""Adaptive Methods for Nonconvex Optimization.

    Differs from Adam in the second-moment update: the accumulator is moved by ``(1 - beta2) * grad^2`` towards
    ``grad^2``, with only the sign of ``exp_avg_sq - grad^2`` deciding the direction.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and of the squared
        gradient.
    :param initial_accumulator: float. initial values for first and second moments.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param maximize: bool. maximize the objective with respect to the params, instead of minimizing.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        betas: BETAS = (0.9, 0.999),
        initial_accumulator: float = 1e-6,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-3,
        maximize: bool = False,
        **kwargs,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.maximize = maximize

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'initial_accumulator': initial_accumulator,
            'eps': eps,
            **kwargs,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Yogi'

    def init_group(self, group: GROUP, **kwargs) -> None:
        """Create the moment buffers for every parameter of ``group`` that has a gradient but no state yet.

        Calling it again is harmless: parameters that already have state are left untouched.

        :param group: GROUP. parameter group to initialise.
        :raises NoSparseGradientError: if a gradient is sparse.
        """
        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad
            if grad.is_sparse:
                raise NoSparseGradientError(str(self))

            state = self.state[p]

            if len(state) == 0:
                state['exp_avg'] = torch.full_like(grad, fill_value=group['initial_accumulator'])
                state['exp_avg_sq'] = torch.full_like(grad, fill_value=group['initial_accumulator'])

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        """Perform a single optimization step.

        :param closure: CLOSURE. optional callable that re-evaluates the model and returns the loss.
        :return: the loss returned by ``closure``, or None when no closure is given.
        """
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # run the (idempotent) initialisation on every step, not only on the first one: a parameter whose
            # gradient was None on the first step would otherwise have no buffers and fail with a KeyError as
            # soon as it receives a gradient.
            self.init_group(group)
            group['step'] = group.get('step', 0) + 1

            beta1, beta2 = group['betas']

            bias_correction1: float = self.debias(beta1, group['step'])
            bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

            step_size: float = self.apply_adam_debias(
                adam_debias=group.get('adam_debias', False), step_size=group['lr'], bias_correction1=bias_correction1
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                self.maximize_gradient(grad, maximize=self.maximize)

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                grad_p2 = grad.mul(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                # exp_avg_sq <- exp_avg_sq - (1 - beta2) * sign(exp_avg_sq - grad^2) * grad^2.
                # `sgn` is the complex-valued counterpart of `sign`.
                difference = exp_avg_sq - grad_p2
                direction = difference.sgn_() if torch.is_complex(exp_avg_sq) else difference.sign_()
                exp_avg_sq.addcmul_(direction, grad_p2, value=-(1.0 - beta2))

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss