
Optimizers

A2Grad

Bases: Optimizer, BaseOptimizer

Optimal Adaptive and Accelerated Stochastic Gradient Descent.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (Optional[float]): learning rate; not needed for this optimizer. (default: None)
    beta (float): beta. (default: 10.0)
    lips (float): Lipschitz constant. (default: 10.0)
    rho (float): degree of weighting decrease; a constant smoothing factor between 0 and 1. (default: 0.5)
    variant (str): type of A2Grad optimizer, one of 'uni', 'inc', or 'exp'. (default: 'uni')
Source code in pytorch_optimizer/optimizer/a2grad.py
class A2Grad(Optimizer, BaseOptimizer):
    r"""Optimal Adaptive and Accelerated Stochastic Gradient Descent.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: Optional[float]. learning rate. not needed.
    :param beta: float. beta.
    :param lips: float. Lipschitz constant.
    :param rho: float. represents the degree of weighting decrease, a constant smoothing factor between 0 and 1.
    :param variant: str. type of A2Grad optimizer. 'uni', 'inc', 'exp'.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: Optional[float] = None,
        beta: float = 10.0,
        lips: float = 10.0,
        rho: float = 0.5,
        variant: str = 'uni',
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(lips, 'lips')
        self.validate_non_negative(rho, 'rho')
        self.validate_options(variant, 'variant', ['uni', 'inc', 'exp'])

        self.variant = variant

        defaults: DEFAULTS = {'beta': beta, 'lips': lips}
        if variant == 'exp':
            defaults.update({'rho': rho})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'A2Grad'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['alpha_k'] = 1.0
                state['v_k'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                state['avg_grad'] = torch.zeros_like(p)
                state['x_k'] = p.clone()
                if self.variant == 'exp':
                    state['v_kk'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            gamma_k: float = 2.0 * group['lips'] / (group['step'] + 1)
            alpha_k_1: float = 2.0 / (group['step'] + 3)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['alpha_k'] = 1.0
                    state['v_k'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    state['avg_grad'] = grad.clone()
                    state['x_k'] = p.clone()
                    if self.variant == 'exp':
                        state['v_kk'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                avg_grad = state['avg_grad']
                avg_grad.add_(grad - avg_grad, alpha=group['step'] + 1)

                delta_k = grad.clone()
                delta_k.add_(avg_grad, alpha=-1.0)

                delta_k_sq = delta_k.pow(2).sum()

                v_k = state['v_k']
                if self.variant in ('uni', 'inc'):
                    if self.variant == 'inc':
                        v_k.mul_((group['step'] / (group['step'] + 1)) ** 2)
                    v_k.add_(delta_k_sq)
                else:
                    v_kk = state['v_kk']
                    v_kk.mul_(group['rho']).add_(delta_k_sq, alpha=1.0 - group['rho'])
                    torch.max(v_kk, v_k, out=v_k)

                h_k = v_k.sqrt()
                if self.variant != 'uni':
                    h_k.mul_(math.sqrt(group['step'] + 1))

                coefficient = -1.0 / (gamma_k + group['beta'] * h_k.item())

                x_k = state['x_k']
                x_k.add_(grad, alpha=coefficient)

                p.mul_(1.0 - alpha_k_1).add_(x_k, alpha=alpha_k_1)
                p.add_(grad, alpha=(1.0 - alpha_k_1) * state['alpha_k'] * coefficient)

                state['alpha_k'] = alpha_k_1

        return loss
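
Since A2Grad derives its step size from lips and beta rather than from lr, a typical call only sets those values and the variant. A minimal usage sketch, assuming the class is imported from the package top level as elsewhere in these docs; the model and data are placeholders:

import torch
from pytorch_optimizer import A2Grad

model = torch.nn.Linear(10, 1)
optimizer = A2Grad(model.parameters(), beta=10.0, lips=10.0, variant='uni')

# placeholder data; any standard training loop works the same way
x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()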

AdaBelief

Bases: Optimizer, BaseOptimizer

Adapting Step-sizes by the Belief in Observed Gradients.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (float): learning rate. (default: 0.001)
    betas (BETAS): coefficients used for computing running averages of gradient and the squared hessian trace. (default: (0.9, 0.999))
    weight_decay (float): weight decay (L2 penalty). (default: 0.0)
    weight_decouple (bool): use decoupled weight decay as in AdamW. (default: True)
    fixed_decay (bool): fix weight decay. (default: False)
    rectify (bool): perform the rectified update similar to RAdam. (default: False)
    n_sma_threshold (int): SMA threshold (5 is recommended). (default: 5)
    degenerated_to_sgd (bool): perform an SGD update when the variance of the gradient is high. (default: True)
    ams_bound (bool): whether to use the AMSBound variant. (default: False)
    r (float): EMA factor; a value between 0.9 and 0.99 is preferred. (default: 0.95)
    adanorm (bool): whether to use the AdaNorm variant. (default: False)
    adam_debias (bool): only correct the denominator to avoid inflating step sizes early in training. (default: False)
    eps (float): term added to the denominator to improve numerical stability. (default: 1e-16)
Source code in pytorch_optimizer/optimizer/adabelief.py
class AdaBelief(Optimizer, BaseOptimizer):
    r"""Adapting Step-sizes by the Belief in Observed Gradients.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param n_sma_threshold: number of SMA threshold (recommended is 5).
    :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-16,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaBelief'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_var'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                if group['ams_bound']:
                    state['max_exp_avg_var'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group['adam_debias'],
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_var'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    if group['ams_bound']:
                        state['max_exp_avg_var'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_var,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )

                if not group['rectify']:
                    de_nom.div_(bias_correction2_sq)
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                    continue

                if n_sma >= self.n_sma_threshold:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    p.add_(exp_avg, alpha=-step_size)

        return loss
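
When rectify=True the update follows RAdam's rectification: the adaptive step is used once n_sma reaches n_sma_threshold, and a plain momentum step is taken before that. A minimal sketch of enabling it, with a placeholder model and data and the top-level import assumed:

import torch
from pytorch_optimizer import AdaBelief

model = torch.nn.Linear(10, 2)
optimizer = AdaBelief(model.parameters(), lr=1e-3, rectify=True, weight_decouple=True)

x, y = torch.randn(16, 10), torch.randint(0, 2, (16,))
optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()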

AdaBound

Bases: Optimizer, BaseOptimizer

Adaptive Gradient Methods with Dynamic Bound of Learning Rate.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (float): learning rate. (default: 0.001)
    final_lr (float): final learning rate. (default: 0.1)
    betas (BETAS): coefficients used for computing running averages of gradient and the squared hessian trace. (default: (0.9, 0.999))
    gamma (float): convergence speed of the bound functions. (default: 0.001)
    weight_decay (float): weight decay (L2 penalty). (default: 0.0)
    weight_decouple (bool): use decoupled weight decay as in AdamW. (default: True)
    fixed_decay (bool): fix weight decay. (default: False)
    ams_bound (bool): whether to use the AMSBound variant. (default: False)
    adam_debias (bool): only correct the denominator to avoid inflating step sizes early in training. (default: False)
    eps (float): term added to the denominator to improve numerical stability. (default: 1e-08)
Source code in pytorch_optimizer/optimizer/adabound.py
class AdaBound(Optimizer, BaseOptimizer):
    r"""Adaptive Gradient Methods with Dynamic Bound of Learning Rate.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param final_lr: float. final learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param gamma: float. convergence speed of the bound functions.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        final_lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        gamma: float = 1e-3,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'final_lr': final_lr,
            'gamma': gamma,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        super().__init__(params, defaults)

        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'AdaBound'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group, base_lr in zip(self.param_groups, self.base_lrs):
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            final_lr: float = group['final_lr'] * group['lr'] / base_lr
            lower_bound: float = final_lr * (1 - 1 / (group['gamma'] * group['step'] + 1))
            upper_bound: float = final_lr * (1 + 1 / (group['gamma'] * group['step']))

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if group['ams_bound']:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                step_size = self.apply_adam_debias(
                    adam_debias=group['adam_debias'],
                    step_size=group['lr'] * bias_correction2_sq,
                    bias_correction1=bias_correction1,
                )

                step_size = torch.full_like(de_nom, fill_value=step_size)
                step_size.div_(de_nom).clamp_(min=lower_bound, max=upper_bound).mul_(exp_avg)

                p.add_(-step_size)

        return loss
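
The lower_bound and upper_bound computed each step clamp the per-coordinate step size around final_lr, with gamma controlling how quickly the bounds tighten, so the update moves from Adam-like toward SGD-like behavior over training. A minimal sketch with a placeholder model and the top-level import assumed:

import torch
from pytorch_optimizer import AdaBound

model = torch.nn.Linear(10, 1)
# step sizes are clipped into [lower_bound, upper_bound], both of which approach final_lr
optimizer = AdaBound(model.parameters(), lr=1e-3, final_lr=1e-1, gamma=1e-3)

x, y = torch.randn(8, 10), torch.randn(8, 1)
optimizer.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.step()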

AdaDelta

Bases: Optimizer, BaseOptimizer

An Adaptive Learning Rate Method.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (float): learning rate. (default: 1.0)
    rho (float): coefficient used for computing a running average of squared gradients. (default: 0.9)
    weight_decay (float): weight decay (L2 penalty). (default: 0.0)
    weight_decouple (bool): use decoupled weight decay as in AdamW. (default: False)
    fixed_decay (bool): fix weight decay. (default: False)
    eps (float): term added to the denominator to improve numerical stability. (default: 1e-06)
Source code in pytorch_optimizer/optimizer/adadelta.py
class AdaDelta(Optimizer, BaseOptimizer):
    r"""An Adaptive Learning Rate Method.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param rho: float. coefficient used for computing a running average of squared gradients.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        rho: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-6,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(rho, 'rho', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'rho': rho,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaDelta'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['square_avg'] = torch.zeros_like(p)
                state['acc_delta'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            rho: float = group['rho']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['square_avg'] = torch.zeros_like(p)
                    state['acc_delta'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                square_avg, acc_delta = state['square_avg'], state['acc_delta']
                square_avg.mul_(rho).addcmul_(grad, grad, value=1.0 - rho)

                std = square_avg.add(group['eps']).sqrt_()
                delta = acc_delta.add(group['eps']).sqrt_().div_(std).mul_(grad)

                acc_delta.mul_(rho).addcmul_(delta, delta, value=1.0 - rho)
                p.add_(delta, alpha=-group['lr'])

        return loss
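
AdaDelta keeps two running averages per parameter (square_avg of squared gradients and acc_delta of squared updates), and the ratio of their square roots sets the effective step, so lr is usually left at 1.0. A minimal sketch with a placeholder model and the top-level import assumed:

import torch
from pytorch_optimizer import AdaDelta

model = torch.nn.Linear(10, 1)
optimizer = AdaDelta(model.parameters(), lr=1.0, rho=0.9)

x, y = torch.randn(8, 10), torch.randn(8, 1)
optimizer.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.step()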

AdaFactor

Bases: Optimizer, BaseOptimizer

Adaptive Learning Rates with Sublinear Memory Cost.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (Optional[float]): learning rate. (default: 0.001)
    betas (BETAS): coefficients used for computing running averages of gradient and the squared hessian trace. (default: (0.9, 0.999))
    decay_rate (float): coefficient used to compute the running average of the squared gradient. (default: -0.8)
    weight_decay (float): weight decay (L2 penalty). (default: 0.0)
    weight_decouple (bool): use decoupled weight decay as in AdamW. (default: True)
    fixed_decay (bool): fix weight decay. (default: False)
    clip_threshold (float): threshold on the root-mean-square of the final gradient update. (default: 1.0)
    ams_bound (bool): whether to use the AMSBound variant. (default: False)
    scale_parameter (bool): if true, the learning rate is scaled by the root-mean-square of the parameter. (default: True)
    relative_step (bool): if true, a time-dependent learning rate is computed instead of the external learning rate. (default: True)
    warmup_init (bool): time-dependent learning rate computation depends on whether warm-up initialization is being used. (default: False)
    eps1 (float): term added to the denominator to improve numerical stability. (default: 1e-30)
    eps2 (float): term added to the denominator to improve numerical stability. (default: 0.001)
Source code in pytorch_optimizer/optimizer/adafactor.py
class AdaFactor(Optimizer, BaseOptimizer):
    r"""Adaptive Learning Rates with Sublinear Memory Cost.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param decay_rate: float. coefficient used to compute running averages of square gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param clip_threshold: float. threshold of root-mean-square of final gradient update.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param scale_parameter: bool. if true, learning rate is scaled by root-mean-square of parameter.
    :param relative_step: bool. if true, time-dependent learning rate is computed instead of external learning rate.
    :param warmup_init: bool. time-dependent learning rate computation depends on whether warm-up initialization
        is being used.
    :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: Optional[float] = 1e-3,
        betas: BETAS = (0.9, 0.999),
        decay_rate: float = -0.8,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        clip_threshold: float = 1.0,
        ams_bound: bool = False,
        scale_parameter: bool = True,
        relative_step: bool = True,
        warmup_init: bool = False,
        eps1: float = 1e-30,
        eps2: float = 1e-3,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.decay_rate = decay_rate
        self.clip_threshold = clip_threshold
        self.eps1 = eps1
        self.eps2 = eps2

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'scale_parameter': scale_parameter,
            'relative_step': relative_step,
            'warmup_init': warmup_init,
            'eps1': eps1,
            'eps2': eps2,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaFactor'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                grad = p.grad

                grad_shape: Tuple[int, ...] = grad.shape
                factored: bool = self.get_options(grad_shape)

                state['exp_avg'] = torch.zeros_like(p)

                if factored:
                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_sq_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                if group['ams_bound']:
                    state['exp_avg_sq_hat'] = torch.zeros_like(grad)

                state['RMS'] = 0.0

    def get_lr(
        self, lr: float, step: int, rms: float, relative_step: bool, warmup_init: bool, scale_parameter: bool
    ) -> float:
        r"""Get AdaFactor learning rate."""
        relative_step_size: float = lr
        if relative_step:
            min_step: float = 1e-6 * step if warmup_init else 1e-2
            relative_step_size = min(min_step, 1.0 / math.sqrt(step))

        param_scale: float = 1.0 if scale_parameter else max(self.eps2, rms)

        return param_scale * relative_step_size

    @staticmethod
    def get_options(shape: Tuple[int, ...]) -> bool:
        r"""Get `factored`."""
        return len(shape) >= 2

    @staticmethod
    def get_rms(x: torch.Tensor) -> float:
        r"""Get RMS."""
        return x.norm(2) / math.sqrt(x.numel())

    @staticmethod
    def approximate_sq_grad(
        exp_avg_sq_row: torch.Tensor,
        exp_avg_sq_col: torch.Tensor,
        output: torch.Tensor,
    ):
        r"""Get approximation of EMA of squared gradient."""
        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, _ = group['betas']

            beta2_t: float = 1.0 - math.pow(group['step'], self.decay_rate)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                grad_shape: Tuple[int, ...] = grad.shape
                factored: bool = self.get_options(grad_shape)

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                        state['exp_avg_sq_col'] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                        )
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    if group['ams_bound']:
                        state['exp_avg_sq_hat'] = torch.zeros_like(grad)

                    state['RMS'] = 0.0

                state['RMS'] = self.get_rms(p)

                lr: float = self.get_lr(
                    lr=group['lr'],
                    step=group['step'],
                    rms=state['RMS'],
                    relative_step=group['relative_step'],
                    warmup_init=group['warmup_init'],
                    scale_parameter=group['scale_parameter'],
                )

                update = torch.mul(grad, grad).add_(self.eps1)

                if factored:
                    exp_avg_sq_row, exp_avg_sq_col = state['exp_avg_sq_row'], state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2_t).add_(update.mean(dim=-1), alpha=1.0 - beta2_t)
                    exp_avg_sq_col.mul_(beta2_t).add_(update.mean(dim=-2), alpha=1.0 - beta2_t)

                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2_t).add_(update, alpha=1.0 - beta2_t)
                    torch.rsqrt(exp_avg_sq, out=update)

                if group['ams_bound']:
                    exp_avg_sq_hat = state['exp_avg_sq_hat']
                    torch.max(exp_avg_sq_hat, 1 / update, out=exp_avg_sq_hat)
                    torch.rsqrt(exp_avg_sq_hat / beta2_t, out=update)

                update.mul_(grad)

                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0)).mul_(lr)

                exp_avg = state['exp_avg']
                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.add_(-exp_avg)

        return loss

approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output) staticmethod

Get approximation of EMA of squared gradient.

Source code in pytorch_optimizer/optimizer/adafactor.py
@staticmethod
def approximate_sq_grad(
    exp_avg_sq_row: torch.Tensor,
    exp_avg_sq_col: torch.Tensor,
    output: torch.Tensor,
):
    r"""Get approximation of EMA of squared gradient."""
    r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
    c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    torch.mul(r_factor, c_factor, out=output)

get_lr(lr, step, rms, relative_step, warmup_init, scale_parameter)

Get AdaFactor learning rate.

Source code in pytorch_optimizer/optimizer/adafactor.py
def get_lr(
    self, lr: float, step: int, rms: float, relative_step: bool, warmup_init: bool, scale_parameter: bool
) -> float:
    r"""Get AdaFactor learning rate."""
    relative_step_size: float = lr
    if relative_step:
        min_step: float = 1e-6 * step if warmup_init else 1e-2
        relative_step_size = min(min_step, 1.0 / math.sqrt(step))

    param_scale: float = 1.0 if scale_parameter else max(self.eps2, rms)

    return param_scale * relative_step_size

get_options(shape) staticmethod

Get factored.

Source code in pytorch_optimizer/optimizer/adafactor.py
@staticmethod
def get_options(shape: Tuple[int, ...]) -> bool:
    r"""Get `factored`."""
    return len(shape) >= 2

get_rms(x) staticmethod

Get RMS.

Source code in pytorch_optimizer/optimizer/adafactor.py
@staticmethod
def get_rms(x: torch.Tensor) -> float:
    r"""Get RMS."""
    return x.norm(2) / math.sqrt(x.numel())
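
With relative_step=True (the default), get_lr derives a time-dependent step size from the step count instead of using the external lr directly, and parameters with two or more dimensions use the factored row/column second-moment statistics. A minimal sketch with a placeholder model and the top-level import assumed:

import torch
from pytorch_optimizer import AdaFactor

model = torch.nn.Linear(32, 8)  # 2D weight -> factored exp_avg_sq_row / exp_avg_sq_col state
optimizer = AdaFactor(model.parameters(), lr=1e-3, scale_parameter=True, relative_step=True)

x, y = torch.randn(4, 32), torch.randn(4, 8)
optimizer.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.step()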

AdaHessian

Bases: Optimizer, BaseOptimizer

An Adaptive Second Order Optimizer for Machine Learning.

Requires `loss.backward(create_graph=True)` in order to calculate hessians.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (float): learning rate. (default: 0.1)
    betas (BETAS): coefficients used for computing running averages of gradient and the squared hessian trace. (default: (0.9, 0.999))
    weight_decay (float): weight decay (L2 penalty). (default: 0.0)
    weight_decouple (bool): use decoupled weight decay as in AdamW. (default: True)
    fixed_decay (bool): fix weight decay. (default: False)
    hessian_power (float): exponent of the hessian trace. (default: 1.0)
    update_period (int): number of steps after which to apply the hessian approximation. (default: 1)
    num_samples (int): number of times to sample z for the approximation of the hessian trace. (default: 1)
    hessian_distribution (HUTCHINSON_G): type of distribution used to initialize the hessian. (default: 'rademacher')
    adam_debias (bool): only correct the denominator to avoid inflating step sizes early in training. (default: False)
    eps (float): term added to the denominator to improve numerical stability. (default: 1e-16)
Source code in pytorch_optimizer/optimizer/adahessian.py
class AdaHessian(Optimizer, BaseOptimizer):
    r"""An Adaptive Second Order Optimizer for Machine Learning.

        Requires `loss.backward(create_graph=True)` in order to calculate hessians.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param hessian_power: float. exponent of the hessian trace.
    :param update_period: int. number of steps after which to apply hessian approximation.
    :param num_samples: int. times to sample `z` for the approximation of the hessian trace.
    :param hessian_distribution: HUTCHINSON_G. type of distribution to initialize hessian.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        hessian_power: float = 1.0,
        update_period: int = 1,
        num_samples: int = 1,
        hessian_distribution: HUTCHINSON_G = 'rademacher',
        adam_debias: bool = False,
        eps: float = 1e-16,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')
        self.validate_range(hessian_power, 'Hessian Power', 0, 1, range_type='(]')

        self.update_period = update_period
        self.num_samples = num_samples
        self.distribution = hessian_distribution

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'hessian_power': hessian_power,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaHessian'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_hessian_diag_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        step: int = self.param_groups[0].get('step', 1)

        if hessian is not None:
            self.set_hessian(self.param_groups, self.state, hessian)
        elif step % self.update_period == 0:
            self.zero_hessian(self.param_groups, self.state)
            self.compute_hutchinson_hessian(
                param_groups=self.param_groups,
                state=self.state,
                num_samples=self.num_samples,
                distribution=self.distribution,
            )

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2: float = 1.0 - beta2 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if 'exp_avg' not in state:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                if 'hessian' in state and (group['step'] % self.update_period == 0 or hessian is not None):
                    exp_hessian_diag_sq.mul_(beta2).addcmul_(state['hessian'], state['hessian'], value=1.0 - beta2)

                de_nom = (exp_hessian_diag_sq / bias_correction2).pow_(group['hessian_power'] / 2).add_(group['eps'])

                step_size: float = self.apply_adam_debias(group['adam_debias'], group['lr'], bias_correction1)
                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss
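
Because the Hutchinson estimate of the Hessian diagonal is built from the gradients themselves, the backward pass must keep its graph; this is what the `loss.backward(create_graph=True)` requirement above refers to. A minimal sketch with a placeholder model and the top-level import assumed:

import torch
from pytorch_optimizer import AdaHessian

model = torch.nn.Linear(8, 1)
optimizer = AdaHessian(model.parameters(), lr=1e-1, update_period=1, num_samples=1)

x, y = torch.randn(16, 8), torch.randn(16, 1)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward(create_graph=True)  # keep the graph so the Hutchinson Hessian estimate can be computed
optimizer.step()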

Adai

Bases: Optimizer, BaseOptimizer

Disentangling the Effects of Adaptive Learning Rate and Momentum.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (float): learning rate. (default: 0.001)
    betas (BETAS): coefficients used for computing running averages of gradient and the squared hessian trace. (default: (0.1, 0.99))
    weight_decay (float): weight decay (L2 penalty). (default: 0.0)
    weight_decouple (bool): use decoupled weight decay as in AdamW. (default: False)
    fixed_decay (bool): fix weight decay. (default: False)
    stable_weight_decay (bool): perform stable weight decay. (default: False)
    dampening (float): dampening for momentum; with dampening < 1, the update shows some adaptive-moment behavior. (default: 1.0)
    use_gc (bool): use gradient centralization. (default: False)
    eps (float): term added to the denominator to improve numerical stability. (default: 0.001)
Source code in pytorch_optimizer/optimizer/adai.py
class Adai(Optimizer, BaseOptimizer):
    r"""Disentangling the Effects of Adaptive Learning Rate and Momentum.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param stable_weight_decay: bool. perform stable weight decay.
    :param dampening: float. dampening for momentum. where dampening < 1, it will show some adaptive-moment behavior.
    :param use_gc: bool. use gradient centralization.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.1, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        stable_weight_decay: bool = False,
        dampening: float = 1.0,
        use_gc: bool = False,
        eps: float = 1e-3,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.use_gc = use_gc

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'stable_weight_decay': stable_weight_decay,
            'dampening': dampening,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adai'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['beta1_prod'] = torch.ones_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        for group in self.param_groups:
            _, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                param_size += p.numel()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['beta1_prod'] = torch.ones_like(p)

                state['step'] += 1

                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                bias_correction2: float = 1.0 - beta2 ** state['step']

                if not group['stable_weight_decay'] and group['weight_decay'] > 0.0:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                exp_avg_sq_hat_sum += exp_avg_sq.sum() / bias_correction2

        if param_size == 0:
            raise ZeroParameterSizeError()

        exp_avg_sq_hat_mean = exp_avg_sq_hat_sum / param_size

        for group in self.param_groups:
            beta0, beta2 = group['betas']
            beta0_dp: float = math.pow(beta0, 1.0 - group['dampening'])
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]

                if group['stable_weight_decay'] and group['weight_decay'] > 0.0:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                bias_correction2: float = 1.0 - beta2 ** state['step']

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                exp_avg_sq_hat = exp_avg_sq / bias_correction2
                beta1 = (
                    1.0
                    - (exp_avg_sq_hat / exp_avg_sq_hat_mean).pow_(1.0 / (3.0 - 2.0 * group['dampening'])).mul_(beta0)
                ).clamp_(0.0, 1.0 - group['eps'])
                beta3 = (1.0 - beta1).pow_(group['dampening'])

                beta1_prod = state['beta1_prod']
                beta1_prod.mul_(beta1)

                exp_avg.mul_(beta1).addcmul_(beta3, grad)
                exp_avg_hat = exp_avg.div(1.0 - beta1_prod).mul_(beta0_dp)

                p.add_(exp_avg_hat, alpha=-group['lr'])

        return loss
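
Adai's step makes two passes over the parameter groups: the first accumulates exp_avg_sq_hat_sum across all parameters with gradients, and the second uses its mean to derive a per-coordinate beta1; if no parameter received a gradient, ZeroParameterSizeError is raised. A minimal sketch with a placeholder model and the top-level import assumed:

import torch
from pytorch_optimizer import Adai

model = torch.nn.Linear(10, 1)
optimizer = Adai(model.parameters(), lr=1e-3, betas=(0.1, 0.99), dampening=1.0)

x, y = torch.randn(8, 10), torch.randn(8, 1)
optimizer.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
optimizer.step()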

Adalite

Bases: Optimizer, BaseOptimizer

Adalite optimizer.

Parameters:

    params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. (required)
    lr (float): learning rate. (default: 0.001)
    betas (BETAS): coefficients used for computing running averages of gradient and the squared hessian trace. (default: (0.9, 0.999))
    weight_decay (float): weight decay (L2 penalty). (default: 0.01)
    weight_decouple (bool): use decoupled weight decay as in AdamW. (default: False)
    fixed_decay (bool): fix weight decay. (default: False)
    g_norm_min (float): lower bound applied to the gradient norm when computing the trust ratio. (default: 1e-10)
    ratio_min (float): lower bound applied to the trust ratio. (default: 0.0001)
    tau (float): temperature of the softmax over the factored second-moment statistics. (default: 1.0)
    eps1 (float): term added to the denominator to improve numerical stability. (default: 1e-06)
    eps2 (float): term added to the denominator to improve numerical stability. (default: 1e-10)
Source code in pytorch_optimizer/optimizer/adalite.py
class Adalite(Optimizer, BaseOptimizer):
    r"""Adalite optimizer.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param g_norm_min: float.
    :param ratio_min: float.
    :param tau: float.
    :param eps1: float. term added to the denominator to improve numerical stability.
    :param eps2: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 1e-2,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        g_norm_min: float = 1e-10,
        ratio_min: float = 1e-4,
        tau: float = 1.0,
        eps1: float = 1e-6,
        eps2: float = 1e-10,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'g_norm_min': g_norm_min,
            'ratio_min': ratio_min,
            'tau': tau,
            'eps1': eps1,
            'eps2': eps2,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adalite'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                if len(p.shape) < 2:
                    state['m_avg'] = torch.zeros_like(p)
                    state['v_avg'] = torch.zeros_like(p)
                else:
                    state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                    state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

                    state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
                    state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
                    state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    if len(p.shape) < 2:
                        state['m_avg'] = torch.zeros_like(p)
                        state['v_avg'] = torch.zeros_like(p)
                    else:
                        state['v_avg_0'] = torch.zeros_like(p.mean(dim=1))
                        state['v_avg_1'] = torch.zeros_like(p.mean(dim=0))

                        state['m_avg_c'] = torch.zeros_like(p.mean(dim=1)[:, None])
                        state['m_avg_r'] = torch.zeros_like(p.mean(dim=0)[None, :])
                        state['m_avg_u'] = torch.zeros_like(p.mean().unsqueeze(0).unsqueeze(0))

                if sum(grad.shape) > 1:
                    trust_ratio = (p.norm() / grad.norm().clip(min=group['g_norm_min'])).clip(min=group['ratio_min'])
                    grad.mul_(trust_ratio)

                if len(grad.shape) < 2:
                    m = state['m_avg']
                    v = state['v_avg']
                else:
                    r, c = state['v_avg_0'][:, None], state['v_avg_1'][None, :]
                    v = (r * c) / r.sum().clamp(min=group['eps2'])
                    m = state['m_avg_c'] @ state['m_avg_u'] @ state['m_avg_r']

                m.lerp_(grad, 1.0 - beta1)
                v.lerp_((grad - m).square(), 1.0 - beta2)

                v_avg = v / (1.0 - beta2 ** group['step'])

                if len(grad.shape) == 2:
                    imp_c = softmax(v.mean(dim=1), dim=0)[:, None]
                    imp_r = softmax(v.mean(dim=0), dim=0)[None, :]
                    m.lerp_(grad, 1.0 - imp_c * imp_r)

                u = m.lerp(grad, 1.0 - beta1)

                if len(grad.shape) < 2:
                    state['m_avg'] = m
                    state['v_avg'] = v
                else:
                    state['v_avg_0'] = v.sum(dim=1)
                    state['v_avg_1'] = v.sum(dim=0) / v.sum().clamp(min=group['eps2'])

                    imp_c = softmax(v.mean(dim=1) / group['tau'], dim=-1)[:, None]
                    imp_r = softmax(v.mean(dim=0) / group['tau'], dim=-1)[None, :]

                    c = ((m * imp_r).sum(dim=1))[:, None]
                    r = ((m * imp_c).sum(dim=0))[None, :]

                    s = (c.T @ m @ r.T) / (c.T @ c @ r @ r.T).clamp(min=group['eps2'])

                    state['m_avg_c'] = c
                    state['m_avg_r'] = r
                    state['m_avg_u'] = s

                u.div_((v_avg + group['eps1']).sqrt())

                u = u.reshape(p.shape)
                u.add_(p, alpha=group['weight_decay'])

                p.add_(u, alpha=-group['lr'])

        return loss
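
A minimal usage sketch, assuming Adalite is importable from the top-level pytorch_optimizer package (the model, data, and hyper-parameters below are placeholders). No extra wiring is needed for the factored state: parameters with two or more dimensions get row/column second-moment factors, while vectors and scalars keep dense m_avg/v_avg buffers.

import torch
from torch import nn

from pytorch_optimizer import Adalite  # assumed top-level export

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = Adalite(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

x, y = torch.randn(8, 16), torch.randn(8, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = (model(x) - y).pow(2).mean()
    loss.backward()
    optimizer.step()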

AdaMax

Bases: Optimizer, BaseOptimizer

An Adaptive and Momental Bound Method for Stochastic Learning.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.999)): coefficients used for computing running averages of gradient and the squared hessian trace.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default False): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
r (float, default 0.95): EMA factor. between 0.9 ~ 0.99 is preferred.
adanorm (bool, default False): whether to use the AdaNorm variant.
adam_debias (bool, default False): Only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 1e-08): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adamax.py
class AdaMax(Optimizer, BaseOptimizer):
    r"""An Adaptive and Momental Bound Method for Stochastic Learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaMax'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_inf'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_inf'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_inf = state['exp_avg'], state['exp_inf']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

                norm_buf = torch.cat(
                    (exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(group['eps']).unsqueeze_(0)),
                    dim=0,
                )
                torch.max(norm_buf, dim=0, keepdim=False, out=(exp_inf, exp_inf.new().long()))

                step_size: float = self.apply_adam_debias(
                    adam_debias=group['adam_debias'],
                    step_size=group['lr'],
                    bias_correction1=bias_correction1,
                )

                p.addcdiv_(exp_avg, exp_inf, value=-step_size)

        return loss
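
The norm_buf/torch.max block above maintains an exponentially weighted infinity norm of the gradients. A minimal sketch of that update in plain torch (variable names are illustrative, not the optimizer's internal API):

import torch

beta2, eps = 0.999, 1e-8
grad = torch.randn(4)
exp_inf = torch.zeros(4)

# decay the previous bound, then take the element-wise max with the
# magnitude of the current gradient (plus eps for numerical stability)
exp_inf = torch.maximum(exp_inf * beta2, grad.abs() + eps)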

AdaMod

Bases: Optimizer, BaseOptimizer

An Adaptive and Momental Bound Method for Stochastic Learning.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.99, 0.9999)): coefficients used for computing running averages of gradient and the squared hessian trace. beta3 is for smoothing coefficient for adaptive learning rates.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default True): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
adam_debias (bool, default False): Only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 1e-08): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adamod.py
class AdaMod(Optimizer, BaseOptimizer):
    r"""An Adaptive and Momental Bound Method for Stochastic Learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
        beta3 is for smoothing coefficient for adaptive learning rates.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.99, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaMod'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_lr'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2, beta3 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_lr'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = self.apply_adam_debias(
                    adam_debias=group['adam_debias'],
                    step_size=group['lr'] * bias_correction2_sq,
                    bias_correction1=bias_correction1,
                )

                step_size = torch.full_like(de_nom, fill_value=step_size)
                step_size.div_(de_nom)

                exp_avg_lr = state['exp_avg_lr']
                exp_avg_lr.mul_(beta3).add_(step_size, alpha=1.0 - beta3)

                torch.min(step_size, exp_avg_lr, out=step_size)
                step_size.mul_(exp_avg)

                p.add_(-step_size)

        return loss
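
AdaMod's distinguishing step is the bound on the per-element step size: the Adam-style step is smoothed with a third EMA (beta3), and the applied step is the element-wise minimum of the raw step and that running bound. A standalone sketch of the clipping, with bias correction omitted for brevity and illustrative names only:

import torch

beta3, lr, eps = 0.9999, 1e-3, 1e-8
exp_avg_sq = torch.rand(4)   # second-moment estimate
exp_avg_lr = torch.zeros(4)  # running bound on the step size

step_size = lr / (exp_avg_sq.sqrt() + eps)                 # raw adaptive step size
exp_avg_lr = beta3 * exp_avg_lr + (1.0 - beta3) * step_size
step_size = torch.minimum(step_size, exp_avg_lr)           # bounded step size actually applied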

AdamP

Bases: Optimizer, BaseOptimizer

Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.999)): coefficients used for computing running averages of gradient and the squared hessian trace.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default True): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
delta (float, default 0.1): threshold that determines whether a set of parameters is scale invariant or not.
wd_ratio (float, default 0.1): relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters.
use_gc (bool, default False): use gradient centralization.
nesterov (bool, default False): enables Nesterov momentum.
r (float, default 0.95): EMA factor. between 0.9 ~ 0.99 is preferred.
adanorm (bool, default False): whether to use the AdaNorm variant.
adam_debias (bool, default False): Only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 1e-08): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adamp.py
class AdamP(Optimizer, BaseOptimizer):
    r"""Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param delta: float. threshold that determines whether a set of parameters is scale invariant or not.
    :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
        on scale-variant parameters.
    :param use_gc: bool. use gradient centralization.
    :param nesterov: bool. enables Nesterov momentum.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        delta: float = 0.1,
        wd_ratio: float = 0.1,
        use_gc: bool = False,
        nesterov: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(wd_ratio, 'wd_ratio', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        self.use_gc = use_gc

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'delta': delta,
            'wd_ratio': wd_ratio,
            'nesterov': nesterov,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamP'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                inv_de_nom = exp_avg_sq.rsqrt().add_(group['eps']).mul_(bias_correction2_sq)

                perturb = exp_avg.clone()
                if group['nesterov']:
                    perturb.mul_(beta1).addcmul_(grad, inv_de_nom, value=1.0 - beta1)
                else:
                    perturb.mul_(inv_de_nom)

                wd_ratio: float = 1.0
                if len(p.shape) > 1:
                    perturb, wd_ratio = projection(
                        p,
                        grad,
                        perturb,
                        group['delta'],
                        group['wd_ratio'],
                        group['eps'],
                    )

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=wd_ratio,
                )

                step_size: float = self.apply_adam_debias(
                    adam_debias=group['adam_debias'],
                    step_size=group['lr'],
                    bias_correction1=bias_correction1,
                )

                p.add_(perturb, alpha=-step_size)

        return loss
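
A minimal usage sketch, assuming AdamP is importable from the top-level pytorch_optimizer package (model and data are placeholders). The projection step only triggers for parameters with more than one dimension, so no extra configuration is needed; wd_ratio then scales the weight decay applied to parameters detected as scale invariant.

import torch
from torch import nn

from pytorch_optimizer import AdamP  # assumed top-level export

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
optimizer = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2, wd_ratio=0.1, nesterov=True)

x = torch.randn(4, 3, 16, 16)
optimizer.zero_grad()
model(x).sum().backward()
optimizer.step()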

AdamS

Bases: Optimizer, BaseOptimizer

Adam with stable weight decay.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.999)): coefficients used for computing running averages of gradient and the squared hessian trace.
weight_decay (float, default 0.0001): weight decay (L2 penalty).
weight_decouple (bool, default True): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
ams_bound (bool, default False): whether to use the AMSBound variant.
r (float, default 0.95): EMA factor. between 0.9 ~ 0.99 is preferred.
adanorm (bool, default False): whether to use the AdaNorm variant.
adam_debias (bool, default False): Only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 1e-08): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adams.py
class AdamS(Optimizer, BaseOptimizer):
    r"""Adam with stable weight decay.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 1e-4,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdamS'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        exp_avg_sq_hat_sum: float = 0.0

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                param_size += p.numel()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if group['ams_bound']:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

                state['step'] += 1

                bias_correction2: float = 1.0 - beta2 ** state['step']

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if group['ams_bound']:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    exp_avg_sq_hat = max_exp_avg_sq
                else:
                    exp_avg_sq_hat = exp_avg_sq

                exp_avg_sq_hat_sum += exp_avg_sq_hat.sum() / bias_correction2

        if param_size == 0:
            raise ZeroParameterSizeError()

        exp_avg_sq_hat_mean: float = math.sqrt(exp_avg_sq_hat_sum / param_size) + self.defaults['eps']

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=1.0 / exp_avg_sq_hat_mean,
                )

                bias_correction1: float = 1.0 - beta1 ** state['step']
                bias_correction2: float = 1.0 - beta2 ** state['step']

                exp_avg_sq_hat = state['max_exp_avg_sq'] if group['ams_bound'] else state['exp_avg_sq']
                exp_avg_sq_hat.div_(bias_correction2)

                de_nom = exp_avg_sq_hat.sqrt().add_(group['eps'])

                step_size: float = self.apply_adam_debias(
                    adam_debias=group['adam_debias'],
                    step_size=group['lr'],
                    bias_correction1=bias_correction1,
                )

                p.addcdiv_(state['exp_avg'], de_nom, value=-step_size)

        return loss
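
The step above is two-pass: the first loop accumulates bias-corrected second moments over every parameter, and only then is the decoupled weight decay applied, scaled by a global statistic. A rough sketch of how that scaling factor is formed, mirroring exp_avg_sq_hat_mean above (the tensors here stand in for per-parameter, already bias-corrected second moments):

import math
import torch

eps = 1e-8
exp_avg_sq_hats = [torch.rand(10), torch.rand(3, 3)]

param_size = sum(t.numel() for t in exp_avg_sq_hats)
hat_sum = sum(t.sum().item() for t in exp_avg_sq_hats)

# global RMS of the second moment; the weight decay is divided by this value,
# which is what makes the decay "stable" across training
exp_avg_sq_hat_mean = math.sqrt(hat_sum / param_size) + eps
decay_ratio = 1.0 / exp_avg_sq_hat_mean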

Adan

Bases: Optimizer, BaseOptimizer

Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.98, 0.92, 0.99)): coefficients used for computing running averages of gradient and the squared hessian trace.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default False): decoupled weight decay.
max_grad_norm (float, default 0.0): max gradient norm to clip.
use_gc (bool, default False): use gradient centralization.
r (float, default 0.95): EMA factor. between 0.9 ~ 0.99 is preferred.
adanorm (bool, default False): whether to use the AdaNorm variant.
eps (float, default 1e-08): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adan.py
class Adan(Optimizer, BaseOptimizer):
    r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. decoupled weight decay.
    :param max_grad_norm: float. max gradient norm to clip.
    :param use_gc: bool. use gradient centralization.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.98, 0.92, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        max_grad_norm: float = 0.0,
        use_gc: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(max_grad_norm, 'max_grad_norm')
        self.validate_non_negative(eps, 'eps')

        self.max_grad_norm = max_grad_norm
        self.use_gc = use_gc

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'max_grad_norm': max_grad_norm,
            'adanorm': adanorm,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Adan'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_diff'] = torch.zeros_like(p)
                state['previous_grad'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
        if self.defaults['max_grad_norm'] == 0.0:
            return 1.0

        global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
        global_grad_norm.sqrt_().add_(self.defaults['eps'])

        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        clip_global_grad_norm = self.get_global_gradient_norm()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2, beta3 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2: float = 1.0 - beta2 ** group['step']
            bias_correction3_sq: float = math.sqrt(1.0 - beta3 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)
                    state['previous_grad'] = grad.clone().mul_(-clip_global_grad_norm)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                grad.mul_(clip_global_grad_norm)

                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                grad_diff = state['previous_grad']
                grad_diff.add_(grad)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=1.0 - beta2)

                grad_diff.mul_(beta2).add_(grad)
                exp_avg_sq.mul_(beta3).addcmul_(grad_diff, grad_diff, value=1.0 - beta3)

                de_nom = exp_avg_sq.sqrt().div_(bias_correction3_sq).add_(group['eps'])

                if group['weight_decouple']:
                    p.mul_(1.0 - group['lr'] * group['weight_decay'])

                p.addcdiv_(exp_avg, de_nom, value=-group['lr'] / bias_correction1)
                p.addcdiv_(exp_avg_diff, de_nom, value=-group['lr'] * beta2 / bias_correction2)

                if not group['weight_decouple']:
                    p.div_(1.0 + group['lr'] * group['weight_decay'])

                state['previous_grad'].copy_(-grad)

        return loss
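
A minimal usage sketch, assuming Adan is importable from the top-level pytorch_optimizer package (model and data are placeholders). Setting max_grad_norm to a positive value enables the global gradient clipping computed in get_global_gradient_norm; leaving it at 0.0 skips clipping entirely.

import torch
from torch import nn

from pytorch_optimizer import Adan  # assumed top-level export

model = nn.Linear(32, 4)
optimizer = Adan(
    model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99),
    weight_decay=2e-2, weight_decouple=True, max_grad_norm=1.0,
)

x, y = torch.randn(16, 32), torch.randint(0, 4, (16,))
optimizer.zero_grad()
nn.functional.cross_entropy(model(x), y).backward()
optimizer.step()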

AdaNorm

Bases: Optimizer, BaseOptimizer

Symbolic Discovery of Optimization Algorithms.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.99)): coefficients used for computing running averages of gradient and the squared hessian trace.
r (float, default 0.95): EMA factor. between 0.9 ~ 0.99 is preferred.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default True): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
ams_bound (bool, default False): whether to use the AMSBound variant.
adam_debias (bool, default False): Only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 1e-08): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adanorm.py
class AdaNorm(Optimizer, BaseOptimizer):
    r"""Symbolic Discovery of Optimization Algorithms.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.99),
        r: float = 0.95,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'r': r,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaNorm'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_var'] = torch.zeros_like(p)
                state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                if group['ams_bound']:
                    state['max_exp_avg_var'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_var'] = torch.zeros_like(p)
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                    if group['ams_bound']:
                        state['max_exp_avg_var'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=True,
                    exp_grad_norm=state['exp_grad_norm'],
                    r=group['r'],
                )

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_var.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_var,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                step_size: float = self.apply_adam_debias(
                    adam_debias=group['adam_debias'],
                    step_size=group['lr'],
                    bias_correction1=bias_correction1,
                )

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss
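
The AdaNorm-specific piece is get_adanorm_gradient, which rescales the raw gradient using an EMA of gradient norms before it enters the first-moment update. A rough sketch of that idea with the r factor shown above; this mirrors the intent, not necessarily the exact helper implementation:

import torch

r = 0.95
grad = torch.randn(8)
exp_grad_norm = torch.zeros(1)

grad_norm = grad.norm()
exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)  # EMA of gradient norms

# if the running norm exceeds the current one, scale the gradient up to it;
# otherwise pass the gradient through unchanged
s_grad = grad * (exp_grad_norm / grad_norm) if exp_grad_norm > grad_norm else grad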

AdaPNM

Bases: Optimizer, BaseOptimizer

Adam + Positive-Negative Momentum Optimizers.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.999, 1.0)): coefficients used for computing running averages of gradient and the squared hessian trace.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default True): use weight_decouple.
fixed_decay (bool, default False): fix weight decay.
ams_bound (bool, default True): whether to use the ams_bound variant.
r (float, default 0.95): EMA factor. between 0.9 ~ 0.99 is preferred.
adanorm (bool, default False): whether to use the AdaNorm variant.
adam_debias (bool, default False): Only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 1e-08): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adapnm.py
class AdaPNM(Optimizer, BaseOptimizer):
    r"""Adam + Positive-Negative Momentum Optimizers.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use weight_decouple.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the ams_bound variant.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        ams_bound: bool = True,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaPNM'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['neg_exp_avg'] = torch.zeros_like(p)
                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2, beta3 = group['betas']

            noise_norm: float = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)  # fmt: skip
            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['neg_exp_avg'] = torch.zeros_like(p)
                    if group['ams_bound']:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['step'] % 2 == 1:
                    exp_avg, neg_exp_avg = state['exp_avg'], state['neg_exp_avg']
                else:
                    exp_avg, neg_exp_avg = state['neg_exp_avg'], state['exp_avg']

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg_sq = state['exp_avg_sq']
                exp_avg.mul_(beta1 ** 2).add_(s_grad, alpha=1.0 - beta1 ** 2)  # fmt: skip
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )
                de_nom.div_(bias_correction2_sq)

                step_size: float = self.apply_adam_debias(
                    adam_debias=group['adam_debias'], step_size=group['lr'], bias_correction1=bias_correction1
                )

                pn_momentum = exp_avg.mul(1.0 + beta3).add_(neg_exp_avg, alpha=-beta3).mul_(1.0 / noise_norm)
                p.addcdiv_(pn_momentum, de_nom, value=-step_size)

        return loss
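
The positive-negative momentum combine at the end of the loop can be read as a single expression: with the current (positive) buffer and the buffer from the previous step's parity, the applied momentum is ((1 + beta3) * exp_avg - beta3 * neg_exp_avg) / sqrt((1 + beta3)^2 + beta3^2). A small standalone sketch with illustrative tensors:

import math
import torch

beta3 = 1.0                    # third beta from the default (0.9, 0.999, 1.0)
exp_avg = torch.randn(4)       # momentum buffer updated on this step's parity
neg_exp_avg = torch.randn(4)   # momentum buffer from the other parity

noise_norm = math.sqrt((1 + beta3) ** 2 + beta3 ** 2)
pn_momentum = ((1.0 + beta3) * exp_avg - beta3 * neg_exp_avg) / noise_norm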

AdaShift

Bases: Optimizer, BaseOptimizer

Decorrelation and Convergence of Adaptive Learning Rate Methods.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.999)): coefficients used for computing running averages of gradient and the squared hessian trace.
keep_num (int, default 10): number of gradients used to compute first moment estimation.
reduce_func (Optional[Callable], default max): function applied to squared gradients to further reduce the correlation. If None, no function is applied.
eps (float, default 1e-10): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/adashift.py
class AdaShift(Optimizer, BaseOptimizer):
    r"""Decorrelation and Convergence of Adaptive Learning Rate Methods.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param keep_num: int. number of gradients used to compute first moment estimation.
    :param reduce_func: Optional[Callable]. function applied to squared gradients to further reduce the correlation.
        If None, no function is applied.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        keep_num: int = 10,
        reduce_func: Optional[Callable] = torch.max,
        eps: float = 1e-10,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_positive(keep_num, 'keep_num')
        self.validate_non_negative(eps, 'eps')

        self.reduce_func: Callable = reduce_func if reduce_func is not None else lambda x: x

        defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'keep_num': keep_num, 'eps': eps}
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaShift'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['grad_queue'] = deque([p.grad.clone()], maxlen=group['keep_num'])
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            exp_weight_sum: int = sum(beta1 ** i for i in range(group['keep_num']))  # fmt: skip
            first_grad_weight: float = beta1 ** (group['keep_num'] - 1) / exp_weight_sum
            last_grad_weight: float = 1.0 / exp_weight_sum

            bias_correction: float = 1.0 - beta2 ** (group['step'] - group['keep_num'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['grad_queue'] = deque([grad.clone()], maxlen=group['keep_num'])
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                grad_queue = state['grad_queue']
                grad_queue.append(grad.clone())

                if len(grad_queue) != group['keep_num']:
                    continue

                offset_grad = grad_queue[0]

                exp_avg = state['exp_avg']
                exp_avg.sub_(offset_grad, alpha=first_grad_weight).mul_(beta1).add_(grad, alpha=last_grad_weight)

                reduced_grad_sq = self.reduce_func(offset_grad.mul_(offset_grad))

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta2).add_(reduced_grad_sq, alpha=1.0 - beta2)

                de_nom = exp_avg_sq.div(bias_correction).sqrt_().add_(group['eps'])

                p.addcdiv_(exp_avg, de_nom, value=-group['lr'])

        return loss
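
A minimal usage sketch, assuming AdaShift is importable from the top-level pytorch_optimizer package (model and data are placeholders). Because each parameter keeps a deque of the last keep_num gradients and the update is skipped until that queue is full, the first few calls to step() leave the parameters unchanged.

import torch
from torch import nn

from pytorch_optimizer import AdaShift  # assumed top-level export

model = nn.Linear(8, 1)
optimizer = AdaShift(model.parameters(), lr=1e-3, keep_num=10, reduce_func=torch.max)

x, y = torch.randn(32, 8), torch.randn(32, 1)
for _ in range(20):  # updates begin once 10 gradients have been queued
    optimizer.zero_grad()
    (model(x) - y).pow(2).mean().backward()
    optimizer.step()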

AdaSmooth

Bases: Optimizer, BaseOptimizer

An Adaptive Learning Rate Method based on Effective Ratio.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 0.001 |
| betas | BETAS | coefficients used for computing running averages of gradient and the squared hessian trace. | (0.5, 0.99) |
| weight_decay | float | weight decay (L2 penalty). | 0.0 |
| weight_decouple | bool | the optimizer uses decoupled weight decay as in AdamW. | False |
| fixed_decay | bool | fix weight decay. | False |
| eps | float | term added to the denominator to improve numerical stability. | 1e-06 |
Source code in pytorch_optimizer/optimizer/adasmooth.py
class AdaSmooth(Optimizer, BaseOptimizer):
    r"""An Adaptive Learning Rate Method based on Effective Ratio.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.5, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-6,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AdaSmooth'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['prev_param'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)
                state['n'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['prev_param'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    state['n'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                prev_param = state['prev_param']
                p_diff = p - prev_param

                s, n = state['s'], state['n']
                s.add_(p_diff)
                n.add_(p_diff.abs())

                c = s.sum().abs_().div_(n.sum())  # e_t
                c.mul_(beta2 - beta1).add_(1.0 - beta2)

                c_p2 = c.pow(2)

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(1.0 - c_p2).addcmul_(grad, grad, value=c_p2)

                step_size = torch.full_like(exp_avg_sq, fill_value=group['lr'])
                step_size.div_((exp_avg_sq + group['eps']).sqrt()).mul_(grad)

                p.add_(-step_size)

                state['prev_param'].copy_(p)

        return loss
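
A minimal usage sketch (assuming AdaSmooth is exported from the top-level pytorch_optimizer package; the model and batch are placeholders):

import torch
from pytorch_optimizer import AdaSmooth

model = torch.nn.Linear(8, 1)
optimizer = AdaSmooth(model.parameters(), lr=1e-3, betas=(0.5, 0.99), weight_decay=0.0)

loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()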

agc(p, grad, agc_eps, agc_clip_val, eps=1e-06)

Clip gradient values in excess of the unit wise norm.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| p | Tensor | parameter. | required |
| grad | Tensor | gradient. | required |
| agc_eps | float | agc epsilon to clip the norm of parameter. | required |
| agc_clip_val | float | norm clip. | required |
| eps | float | small value that only guards against division by zero; unrelated to the standard optimizer eps. | 1e-06 |
Source code in pytorch_optimizer/optimizer/agc.py
def agc(p: torch.Tensor, grad: torch.Tensor, agc_eps: float, agc_clip_val: float, eps: float = 1e-6) -> torch.Tensor:
    r"""Clip gradient values in excess of the unit wise norm.

    :param p: torch.Tensor. parameter.
    :param grad: torch.Tensor, gradient.
    :param agc_eps: float. agc epsilon to clip the norm of parameter.
    :param agc_clip_val: float. norm clip.
    :param eps: float. small value that only guards against division by zero; unrelated to the standard optimizer eps.
    """
    p_norm = unit_norm(p).clamp_(agc_eps)
    g_norm = unit_norm(grad)

    max_norm = p_norm * agc_clip_val

    clipped_grad = grad * (max_norm / g_norm.clamp_min_(eps))

    return torch.where(g_norm > max_norm, clipped_grad, grad)
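
A small sketch showing agc applied as a standalone gradient transform before an optimizer step; the import path follows the source location given above, and the agc_eps / agc_clip_val values are only illustrative.

import torch
from pytorch_optimizer.optimizer.agc import agc

model = torch.nn.Linear(8, 8)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

with torch.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p.grad.copy_(agc(p, p.grad, agc_eps=1e-3, agc_clip_val=1e-2))
# ... then run the optimizer step as usual.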

AggMo

Bases: Optimizer, BaseOptimizer

Aggregated Momentum: Stability Through Passive Damping.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 0.001 |
| betas | BETAS | momentum coefficients; one momentum buffer is maintained per value and their updates are averaged. | (0.0, 0.9, 0.99) |
| weight_decay | float | weight decay (L2 penalty). | 0.0 |
| weight_decouple | bool | the optimizer uses decoupled weight decay as in AdamW. | False |
| fixed_decay | bool | fix weight decay. | False |
Source code in pytorch_optimizer/optimizer/aggmo.py
class AggMo(Optimizer, BaseOptimizer):
    r"""Aggregated Momentum: Stability Through Passive Damping.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. momentum coefficients; one momentum buffer is maintained per value and their updates are averaged.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.0, 0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AggMo'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['momentum_buffer'] = {beta: torch.zeros_like(p) for beta in group['betas']}

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            betas = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['momentum_buffer'] = {beta: torch.zeros_like(p) for beta in betas}

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                for beta in betas:
                    buf = state['momentum_buffer'][beta]
                    buf.mul_(beta).add_(grad)

                    p.add_(buf, alpha=-group['lr'] / len(betas))

        return loss
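
A minimal usage sketch (assuming AggMo is exported from the top-level pytorch_optimizer package). As the step above shows, each entry in betas gets its own momentum buffer, and the applied update is the average of those buffers scaled by the learning rate.

import torch
from pytorch_optimizer import AggMo

model = torch.nn.Linear(8, 1)
optimizer = AggMo(model.parameters(), lr=1e-3, betas=(0.0, 0.9, 0.99))

loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()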

Aida

Bases: Optimizer, BaseOptimizer

A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 0.001 |
| betas | BETAS | coefficients used for computing running averages of gradient and the squared hessian trace. | (0.9, 0.999) |
| k | int | number of vectors projected per iteration. | 2 |
| xi | float | term used in vector projections to avoid division by zero. | 1e-20 |
| weight_decay | float | weight decay (L2 penalty). | 0.0 |
| weight_decouple | bool | the optimizer uses decoupled weight decay as in AdamW. | False |
| fixed_decay | bool | fix weight decay. | False |
| rectify | bool | perform the rectified update similar to RAdam. | False |
| n_sma_threshold | int | SMA threshold (a value of 5 is recommended). | 5 |
| degenerated_to_sgd | bool | perform an SGD update when the variance of the gradient is high. | True |
| ams_bound | bool | whether to use the AMSBound variant. | False |
| r | float | EMA factor; values between 0.9 and 0.99 are preferred. | 0.95 |
| adanorm | bool | whether to use the AdaNorm variant. | False |
| adam_debias | bool | only correct the denominator to avoid inflating step sizes early in training. | False |
| eps | float | term added to the denominator to improve numerical stability. | 1e-08 |
Source code in pytorch_optimizer/optimizer/aida.py
class Aida(Optimizer, BaseOptimizer):
    r"""A DNN Optimizer that Improves over AdaBelief by Suppression of the Adaptive Stepsize Range.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param k: int. number of vectors projected per iteration.
    :param xi: float. term used in vector projections to avoid division by zero.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param n_sma_threshold: int. SMA threshold (a value of 5 is recommended).
    :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param r: float. EMA factor; values between 0.9 and 0.99 are preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        k: int = 2,
        xi: float = 1e-20,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(k, 'k')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(xi, 'xi')
        self.validate_non_negative(eps, 'eps')

        self.k = k
        self.xi = xi
        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Aida'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_var'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                if group['ams_bound']:
                    state['max_exp_avg_var'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group['adam_debias'],
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_var'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    if group['ams_bound']:
                        state['max_exp_avg_var'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)

                proj_g = grad.detach().clone()
                proj_m = exp_avg.detach().clone()

                for _ in range(self.k):
                    proj_sum_gm = torch.sum(torch.mul(proj_g, proj_m))

                    scalar_g = proj_sum_gm / (torch.sum(torch.pow(proj_g, 2)).add_(self.xi))
                    scalar_m = proj_sum_gm / (torch.sum(torch.pow(proj_m, 2)).add_(self.xi))

                    proj_g.mul_(scalar_g)
                    proj_m.mul_(scalar_m)

                grad_residual = proj_m - proj_g
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_var,
                    max_exp_avg_sq=state.get('max_exp_avg_var', None),
                    eps=group['eps'],
                )

                if not group['rectify']:
                    de_nom.div_(bias_correction2_sq)
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                    continue

                if n_sma >= self.n_sma_threshold:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    p.add_(exp_avg, alpha=-step_size)

        return loss
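
A minimal usage sketch (assuming Aida is exported from the top-level pytorch_optimizer package; k=2 projection steps is the default shown above):

import torch
from pytorch_optimizer import Aida

model = torch.nn.Linear(8, 1)
optimizer = Aida(model.parameters(), lr=1e-3, k=2, rectify=False)

loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()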

AliG

Bases: Optimizer, BaseOptimizer

Adaptive Learning Rates for Interpolation with Gradients.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| max_lr | Optional[float] | max learning rate. | None |
| projection_fn | Optional[Callable] | projection function to enforce constraints. | None |
| momentum | float | momentum. | 0.0 |
| adjusted_momentum | bool | if True, use pytorch-like momentum, instead of standard Nesterov momentum. | False |
Source code in pytorch_optimizer/optimizer/alig.py
class AliG(Optimizer, BaseOptimizer):
    r"""Adaptive Learning Rates for Interpolation with Gradients.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param max_lr: Optional[float]. max learning rate.
    :param projection_fn: Callable. projection function to enforce constraints.
    :param momentum: float. momentum.
    :param adjusted_momentum: bool. if True, use pytorch-like momentum, instead of standard Nesterov momentum.
    """

    def __init__(
        self,
        params: PARAMETERS,
        max_lr: Optional[float] = None,
        projection_fn: Optional[Callable] = None,
        momentum: float = 0.0,
        adjusted_momentum: bool = False,
    ):
        self.validate_learning_rate(max_lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)

        self.projection_fn = projection_fn

        defaults: DEFAULTS = {'max_lr': max_lr, 'adjusted_momentum': adjusted_momentum, 'momentum': momentum}
        super().__init__(params, defaults)

        if self.projection_fn is not None:
            self.projection_fn()

    def __str__(self) -> str:
        return 'AliG'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                if group['momentum'] > 0.0:
                    state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def compute_step_size(self, loss: float) -> float:
        r"""Compute step_size."""
        global_grad_norm = get_global_gradient_norm(self.param_groups, torch.device('cpu'))
        global_grad_norm.add_(1e-6)

        return loss / global_grad_norm.item()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        if closure is None:
            raise NoClosureError('AliG', '(e.g. `optimizer.step(lambda: float(loss))`).')

        loss = closure()

        un_clipped_step_size: float = self.compute_step_size(loss)

        for group in self.param_groups:
            step_size = group['step_size'] = (
                min(un_clipped_step_size, group['max_lr']) if group['max_lr'] is not None else un_clipped_step_size
            )
            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0 and momentum > 0.0:
                    state['momentum_buffer'] = torch.zeros_like(p)

                p.add_(grad, alpha=-step_size)

                if momentum > 0.0:
                    buffer = state['momentum_buffer']

                    if group['adjusted_momentum']:
                        buffer.mul_(momentum).sub_(grad)
                        p.add_(buffer, alpha=step_size * momentum)
                    else:
                        buffer.mul_(momentum).add_(grad, alpha=-step_size)
                        p.add_(buffer, alpha=momentum)

            if self.projection_fn is not None:
                self.projection_fn()

        return loss

compute_step_size(loss)

Compute step_size.

Source code in pytorch_optimizer/optimizer/alig.py
@torch.no_grad()
def compute_step_size(self, loss: float) -> float:
    r"""Compute step_size."""
    global_grad_norm = get_global_gradient_norm(self.param_groups, torch.device('cpu'))
    global_grad_norm.add_(1e-6)

    return loss / global_grad_norm.item()
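
A minimal usage sketch (assuming AliG is exported from the top-level pytorch_optimizer package). Unlike the other optimizers on this page, step requires a closure that returns the loss value, which is used to compute the adaptive step size.

import torch
from pytorch_optimizer import AliG

model = torch.nn.Linear(8, 1)
optimizer = AliG(model.parameters(), max_lr=0.1, momentum=0.9)

loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optimizer.step(lambda: float(loss))  # closure returning the float loss, as required above
optimizer.zero_grad()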

Amos

Bases: Optimizer, BaseOptimizer

An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 0.001 |
| beta | float | a float slightly less than 1, similar to beta2 in Adam; setting 1 - beta to the same order of magnitude as the learning rate is recommended. | 0.999 |
| momentum | float | exponential decay rate for the optional moving average of updates. | 0.0 |
| extra_l2 | float | additional L2 regularization. | 0.0 |
| c_coef | float | coefficient for decay_factor_c. | 0.25 |
| d_coef | float | coefficient for decay_factor_d. | 0.25 |
| eps | float | term added to the denominator to improve numerical stability. | 1e-18 |
Source code in pytorch_optimizer/optimizer/amos.py
class Amos(Optimizer, BaseOptimizer):
    r"""An Adam-style Optimizer with Adaptive Weight Decay towards Model-Oriented Scale.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. a float slightly less than 1, similar to beta2 in Adam. We recommend setting `1 - beta`
        to the same order of magnitude as the learning rate.
    :param momentum: float. Exponential decay rate for optional moving average of updates.
    :param extra_l2: float. Additional L2 regularization.
    :param c_coef: float. Coefficient for decay_factor_c.
    :param d_coef: float. Coefficient for decay_factor_d.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        beta: float = 0.999,
        momentum: float = 0.0,
        extra_l2: float = 0.0,
        c_coef: float = 0.25,
        d_coef: float = 0.25,
        eps: float = 1e-18,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(extra_l2, 'extra_l2')
        self.validate_non_negative(eps, 'eps')

        self.c_coef = c_coef
        self.d_coef = d_coef

        defaults: DEFAULTS = {
            'lr': lr,
            'beta': beta,
            'momentum': momentum,
            'extra_l2': extra_l2,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Amos'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg_sq'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                state['decay'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                if group['momentum'] > 0.0:
                    state['exp_avg'] = torch.zeros_like(p)

    @staticmethod
    def get_scale(p: torch.Tensor) -> float:
        r"""Get expected scale for model weights."""
        if len(p.shape) == 1:  # expected 'bias'
            return 0.5
        if len(p.shape) == 2:  # expected Embedding, Linear, ...
            return math.sqrt(2 / p.size(1))
        return math.sqrt(1 / p.size(1))

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            momentum, beta = group['momentum'], group['beta']

            lr_sq: float = math.sqrt(group['lr'])
            bias_correction: float = 1.0 - beta ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg_sq'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                    state['decay'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                    if group['momentum'] > 0.0:
                        state['exp_avg'] = torch.zeros_like(p)

                g2 = grad.pow(2).mean()
                init_lr: float = group['lr'] * self.get_scale(p)

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(beta).add_(g2, alpha=1.0 - beta)

                r_v_hat = bias_correction / (exp_avg_sq + group['eps'])

                b = state['decay']
                decay_factor_c = torch.rsqrt(1.0 + self.c_coef * lr_sq * b)
                decay_factor_d = torch.reciprocal(1.0 + self.d_coef * math.sqrt(init_lr) * b)

                gamma = decay_factor_c * (group['lr'] ** 2) * r_v_hat * g2

                update = p.clone()
                update.mul_((gamma - group['extra_l2']) / 2.0)
                update.add_(r_v_hat.sqrt() * grad, alpha=init_lr)
                update.mul_(decay_factor_d)

                b.mul_(1.0 + gamma).add_(gamma)

                if momentum > 0.0:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(momentum).add_(update, alpha=1.0 - momentum)

                    update.copy_(exp_avg)

                p.add_(-update)

        return loss

get_scale(p) staticmethod

Get expected scale for model weights.

Source code in pytorch_optimizer/optimizer/amos.py
@staticmethod
def get_scale(p: torch.Tensor) -> float:
    r"""Get expected scale for model weights."""
    if len(p.shape) == 1:  # expected 'bias'
        return 0.5
    if len(p.shape) == 2:  # expected Embedding, Linear, ...
        return math.sqrt(2 / p.size(1))
    return math.sqrt(1 / p.size(1))
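
A minimal usage sketch (assuming Amos is exported from the top-level pytorch_optimizer package). Following the parameter note above, 1 - beta is kept on the same order of magnitude as the learning rate.

import torch
from pytorch_optimizer import Amos

model = torch.nn.Linear(8, 1)
optimizer = Amos(model.parameters(), lr=1e-3, beta=0.999, momentum=0.9)  # 1 - beta == lr == 1e-3

loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()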

Apollo

Bases: Optimizer, BaseOptimizer

An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 0.001 |
| init_lr | Optional[float] | initial learning rate (default lr / 1000). | None |
| beta | float | coefficient used for computing running averages of gradient. | 0.9 |
| rebound | str | rectified bound for diagonal hessian. (constant, belief). | 'constant' |
| weight_decay | float | weight decay (L2 penalty). | 0.0 |
| weight_decay_type | str | type of weight decay. (l2, decoupled, stable). | 'l2' |
| warmup_steps | int | number of warmup steps. | 500 |
| eps | float | term added to the denominator to improve numerical stability. | 0.0001 |
Source code in pytorch_optimizer/optimizer/apollo.py
class Apollo(Optimizer, BaseOptimizer):
    r"""An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param init_lr: Optional[float]. initial learning rate (default lr / 1000).
    :param beta: float. coefficient used for computing running averages of gradient.
    :param rebound: str. rectified bound for diagonal hessian. (constant, belief).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decay_type: str. type of weight decay. (l2, decoupled, stable).
    :param warmup_steps: int. number of warmup steps.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        init_lr: Optional[float] = None,
        beta: float = 0.9,
        rebound: str = 'constant',
        weight_decay: float = 0.0,
        weight_decay_type: str = 'l2',
        warmup_steps: int = 500,
        eps: float = 1e-4,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_options(rebound, 'rebound', ['constant', 'belief'])
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_options(weight_decay_type, 'weight_decay_type', ['l2', 'decoupled', 'stable'])
        self.validate_non_negative(eps, 'eps')

        self.lr = lr
        self.warmup_steps = warmup_steps
        self.init_lr: float = init_lr if init_lr is not None else lr / 1000.0

        defaults: DEFAULTS = {
            'lr': lr,
            'init_lr': self.init_lr,
            'beta': beta,
            'rebound': rebound,
            'weight_decay': weight_decay,
            'weight_decay_type': weight_decay_type,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Apollo'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg_grad'] = torch.zeros_like(p)
                state['approx_hessian'] = torch.zeros_like(p)
                state['update'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            current_lr: float = (
                group['lr']
                if group['step'] >= self.warmup_steps
                else (self.lr - group['init_lr']) * group['step'] / self.warmup_steps + group['init_lr']
            )

            weight_decay, eps = group['weight_decay'], group['eps']

            bias_correction: float = 1.0 - group['beta'] ** group['step']
            alpha: float = (1.0 - group['beta']) / bias_correction

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg_grad'] = torch.zeros_like(p)
                    state['approx_hessian'] = torch.zeros_like(p)
                    state['update'] = torch.zeros_like(p)

                if weight_decay > 0.0 and group['weight_decay_type'] == 'l2':
                    grad.add_(p, alpha=weight_decay)

                exp_avg_grad, b, d_p = state['exp_avg_grad'], state['approx_hessian'], state['update']

                delta_grad = grad - exp_avg_grad
                if group['rebound'] == 'belief':
                    rebound = delta_grad.norm(p=np.inf)
                else:
                    rebound = 1e-2
                    eps /= rebound

                exp_avg_grad.add_(delta_grad, alpha=alpha)

                de_nom = d_p.norm(p=4).add_(eps)
                d_p.div_(de_nom)

                v_sq = d_p.mul(d_p)
                delta = delta_grad.div_(de_nom).mul_(d_p).sum().mul(-alpha) - b.mul(v_sq).sum()

                b.addcmul_(v_sq, delta)

                de_nom = b.abs().clamp_(min=rebound)
                if group['rebound'] == 'belief':
                    de_nom.add_(eps / alpha)

                d_p.copy_(exp_avg_grad.div(de_nom))

                if weight_decay > 0.0 and group['weight_decay_type'] != 'l2':
                    if group['weight_decay_type'] == 'stable':
                        weight_decay /= de_nom.mean().item()

                    d_p.add_(p, alpha=weight_decay)

                p.add_(d_p, alpha=-current_lr)

        return loss
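
A minimal usage sketch (assuming Apollo is exported from the top-level pytorch_optimizer package). As in the step method above, the effective learning rate warms up linearly from init_lr to lr over warmup_steps steps.

import torch
from pytorch_optimizer import Apollo

model = torch.nn.Linear(8, 1)
optimizer = Apollo(model.parameters(), lr=1e-3, rebound='constant', warmup_steps=500)

loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()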

AvaGrad

Bases: Optimizer, BaseOptimizer

Domain-independent Dominance of Adaptive Methods.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 0.1 |
| betas | BETAS | coefficients used for computing running averages of gradient and the squared hessian trace. | (0.9, 0.999) |
| weight_decay | float | weight decay (L2 penalty). | 0.0 |
| weight_decouple | bool | the optimizer uses decoupled weight decay as in AdamW. | True |
| fixed_decay | bool | fix weight decay. | False |
| adam_debias | bool | only correct the denominator to avoid inflating step sizes early in training. | False |
| eps | float | term added to the denominator to improve numerical stability. | 0.1 |
Source code in pytorch_optimizer/optimizer/avagrad.py
class AvaGrad(Optimizer, BaseOptimizer):
    r"""Domain-independent Dominance of Adaptive Methods.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-1,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'adam_debias': adam_debias,
            'gamma': None,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AvaGrad'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
            prev_bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] - 1))

            squared_norm: float = 0.0
            num_params: float = 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg = state['exp_avg']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                exp_avg_sq = state['exp_avg_sq']
                sqrt_exp_avg_sq = exp_avg_sq.sqrt()

                if group['step'] > 1:
                    de_nom = sqrt_exp_avg_sq.div(prev_bias_correction2_sq).add_(group['eps'])

                    step_size: float = self.apply_adam_debias(
                        adam_debias=group['adam_debias'],
                        step_size=group['gamma'] * group['lr'],
                        bias_correction1=bias_correction1,
                    )
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                param_wise_lr = sqrt_exp_avg_sq.div_(bias_correction2_sq).add_(group['eps'])
                squared_norm += param_wise_lr.norm(-2) ** -2
                num_params += param_wise_lr.numel()

            group['gamma'] = 0.0 if num_params == 0.0 else 1.0 / math.sqrt(squared_norm / num_params)

        return loss
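
A minimal usage sketch (assuming AvaGrad is exported from the top-level pytorch_optimizer package). Note the unusually large default lr and eps (both 0.1); as the step method above shows, the first step only accumulates statistics because gamma is computed at the end of each step.

import torch
from pytorch_optimizer import AvaGrad

model = torch.nn.Linear(8, 1)
optimizer = AvaGrad(model.parameters(), lr=1e-1, eps=1e-1)

for _ in range(2):  # parameters start updating from the second step onward
    loss = model(torch.randn(16, 8)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()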

CAME

Bases: Optimizer, BaseOptimizer

Confidence-guided Adaptive Memory Efficient Optimization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 0.0002 |
| betas | BETAS | coefficients used for computing running averages of gradient and the squared hessian trace. | (0.9, 0.999, 0.9999) |
| weight_decay | float | weight decay (L2 penalty). | 0.0 |
| weight_decouple | bool | the optimizer uses decoupled weight decay as in AdamW. | True |
| fixed_decay | bool | fix weight decay. | False |
| clip_threshold | float | threshold of root-mean-square of final gradient update. | 1.0 |
| ams_bound | bool | whether to use the AMSBound variant. | False |
| eps1 | float | term added to the squared gradient to improve numerical stability. | 1e-30 |
| eps2 | float | term added to the squared update residual to improve numerical stability. | 1e-16 |
Source code in pytorch_optimizer/optimizer/came.py
class CAME(Optimizer, BaseOptimizer):
    r"""Confidence-guided Adaptive Memory Efficient Optimization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param clip_threshold: float. threshold of root-mean-square of final gradient update.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param eps1: float. term added to the squared gradient to improve numerical stability.
    :param eps2: float. term added to the squared update residual to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 2e-4,
        betas: BETAS = (0.9, 0.999, 0.9999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        clip_threshold: float = 1.0,
        ams_bound: bool = False,
        eps1: float = 1e-30,
        eps2: float = 1e-16,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps1, 'eps1')
        self.validate_non_negative(eps2, 'eps2')

        self.clip_threshold = clip_threshold
        self.eps1 = eps1
        self.eps2 = eps2

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'eps1': eps1,
            'eps2': eps2,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'CAME'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                grad = p.grad

                grad_shape: Tuple[int, ...] = grad.shape
                factored: bool = self.get_options(grad_shape)

                state['exp_avg'] = torch.zeros_like(p)

                if factored:
                    state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_sq_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                    state['exp_avg_res_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                    state['exp_avg_res_col'] = torch.zeros(
                        grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                    )
                else:
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                if group['ams_bound']:
                    state['exp_avg_sq_hat'] = torch.zeros_like(grad)

                state['RMS'] = 0.0

    @staticmethod
    def get_options(shape: Tuple[int, ...]) -> bool:
        r"""Get `factored`."""
        return len(shape) >= 2

    @staticmethod
    def get_rms(x: torch.Tensor) -> float:
        r"""Get RMS."""
        return x.norm(2) / math.sqrt(x.numel())

    @staticmethod
    def approximate_sq_grad(
        exp_avg_sq_row: torch.Tensor,
        exp_avg_sq_col: torch.Tensor,
        output: torch.Tensor,
    ):
        r"""Get approximation of EMA of squared gradient."""
        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2, beta3 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                grad_shape: Tuple[int, ...] = grad.shape
                factored: bool = self.get_options(grad_shape)

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                        state['exp_avg_sq_col'] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                        )
                        state['exp_avg_res_row'] = torch.zeros(grad_shape[:-1], dtype=grad.dtype, device=grad.device)
                        state['exp_avg_res_col'] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:], dtype=grad.dtype, device=grad.device
                        )
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    if group['ams_bound']:
                        state['exp_avg_sq_hat'] = torch.zeros_like(grad)

                    state['RMS'] = 0.0

                state['RMS'] = self.get_rms(p)

                update = torch.mul(grad, grad).add_(self.eps1)

                if factored:
                    exp_avg_sq_row, exp_avg_sq_col = state['exp_avg_sq_row'], state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2).add_(update.mean(dim=-1), alpha=1.0 - beta2)
                    exp_avg_sq_col.mul_(beta2).add_(update.mean(dim=-2), alpha=1.0 - beta2)

                    self.approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, update)
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(beta2).add_(update, alpha=1.0 - beta2)
                    torch.rsqrt(exp_avg_sq, out=update)

                if group['ams_bound']:
                    exp_avg_sq_hat = state['exp_avg_sq_hat']
                    torch.max(exp_avg_sq_hat, 1 / update, out=exp_avg_sq_hat)
                    torch.rsqrt(exp_avg_sq_hat / beta2, out=update)

                update.mul_(grad)

                update.div_((self.get_rms(update) / self.clip_threshold).clamp_(min=1.0))

                exp_avg = state['exp_avg']
                exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

                res = update - exp_avg
                res.pow_(2).add_(self.eps2)

                if factored:
                    exp_avg_res_row, exp_avg_res_col = state['exp_avg_res_row'], state['exp_avg_res_col']

                    exp_avg_res_row.mul_(beta3).add_(res.mean(dim=-1), alpha=1.0 - beta3)
                    exp_avg_res_col.mul_(beta3).add_(res.mean(dim=-2), alpha=1.0 - beta3)

                    self.approximate_sq_grad(exp_avg_res_row, exp_avg_res_col, update)
                    update.mul_(exp_avg)
                else:
                    update = exp_avg

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                update.mul_(group['lr'])

                p.add_(-update)

        return loss

approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output) staticmethod

Get approximation of EMA of squared gradient.

Source code in pytorch_optimizer/optimizer/came.py
@staticmethod
def approximate_sq_grad(
    exp_avg_sq_row: torch.Tensor,
    exp_avg_sq_col: torch.Tensor,
    output: torch.Tensor,
):
    r"""Get approximation of EMA of squared gradient."""
    r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
    c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    torch.mul(r_factor, c_factor, out=output)

get_options(shape) staticmethod

Get factored.

Source code in pytorch_optimizer/optimizer/came.py
@staticmethod
def get_options(shape: Tuple[int, ...]) -> bool:
    r"""Get `factored`."""
    return len(shape) >= 2

get_rms(x) staticmethod

Get RMS.

Source code in pytorch_optimizer/optimizer/came.py
@staticmethod
def get_rms(x: torch.Tensor) -> float:
    r"""Get RMS."""
    return x.norm(2) / math.sqrt(x.numel())
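
A minimal usage sketch (assuming CAME is exported from the top-level pytorch_optimizer package). Parameters with two or more dimensions use the factored row/column second-moment statistics selected by get_options above; 1-D parameters fall back to full statistics.

import torch
from pytorch_optimizer import CAME

model = torch.nn.Linear(8, 8)  # 2-D weight -> factored statistics; 1-D bias -> full statistics
optimizer = CAME(model.parameters(), lr=2e-4, betas=(0.9, 0.999, 0.9999), weight_decay=1e-2)

loss = model(torch.randn(16, 8)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()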

DAdaptAdaGrad

Bases: Optimizer, BaseOptimizer

AdaGrad with D-Adaptation. Leave LR set to 1 unless you encounter instability.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | PARAMETERS | iterable of parameters to optimize or dicts defining parameter groups. | required |
| lr | float | learning rate. | 1.0 |
| momentum | float | momentum. | 0.0 |
| d0 | float | initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. | 1e-06 |
| growth_rate | float | prevent the D estimate from growing faster than this multiplicative rate. | float('inf') |
| weight_decay | float | weight decay (L2 penalty). | 0.0 |
| weight_decouple | bool | the optimizer uses decoupled weight decay as in AdamW. | False |
| fixed_decay | bool | fix weight decay. | False |
| eps | float | term added to the denominator to improve numerical stability. | 0.0 |
Source code in pytorch_optimizer/optimizer/dadapt.py
class DAdaptAdaGrad(Optimizer, BaseOptimizer):
    r"""AdaGrad with D-Adaptation. Leave LR set to 1 unless you encounter instability.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        momentum: float = 0.0,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 0.0,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'd': d0,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'k': 0,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdaGrad'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
                state['sk'] = torch.zeros_like(p)
                state['x0'] = torch.clone(p)
                if p.grad.is_sparse:
                    state['weighted_sk'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        device = group['params'][0].device

        d, lr = group['d'], group['lr']
        d_lr: float = d * lr

        g_sq = torch.tensor([0.0], device=device)
        sk_sq_weighted_change = torch.tensor([0.0], device=device)
        sk_l1_change = torch.tensor([0.0], device=device)
        if 'gsq_weighted' not in group:
            group['gsq_weighted'] = torch.tensor([0.0], device=device)
        if 'sk_sq_weighted' not in group:
            group['sk_sq_weighted'] = torch.tensor([0.0], device=device)
        if 'sk_l1' not in group:
            group['sk_l1'] = torch.tensor([0.0], device=device)

        gsq_weighted = group['gsq_weighted']
        sk_sq_weighted = group['sk_sq_weighted']
        sk_l1 = group['sk_l1']

        for group in self.param_groups:
            eps = group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if 'alpha_k' not in state:
                    state['alpha_k'] = torch.full_like(p, fill_value=1e-6)
                    state['sk'] = torch.zeros_like(p)
                    state['x0'] = torch.clone(p)
                    if grad.is_sparse:
                        state['weighted_sk'] = torch.zeros_like(p)

                sk, alpha_k = state['sk'], state['alpha_k']

                if grad.is_sparse:
                    weighted_sk = state['weighted_sk']

                    grad = grad.coalesce()

                    vk = grad._values().pow(2)
                    sk_masked = sk.sparse_mask(grad).coalesce()
                    old_sk_l1_masked = sk_masked._values().abs().sum()

                    sk.add_(grad, alpha=d_lr)

                    sk_masked = sk.sparse_mask(grad).coalesce()
                    alpha_k_masked = alpha_k.sparse_mask(grad).coalesce()
                    weighted_sk_masked = weighted_sk.sparse_mask(grad).coalesce()

                    # update alpha before step
                    alpha_k_p1_masked = alpha_k_masked._values() + vk

                    alpha_k_delta_masked = alpha_k_p1_masked - alpha_k_masked._values()
                    alpha_k_delta = torch.sparse_coo_tensor(grad.indices(), alpha_k_delta_masked, grad.shape)
                    alpha_k.add_(alpha_k_delta)

                    de_nom = torch.sqrt(alpha_k_p1_masked + eps)

                    grad_sq = vk.div(de_nom).sum()
                    g_sq.add_(grad_sq)

                    # update weighted sk sq tracking
                    weighted_sk_p1_masked = sk_masked._values().pow(2).div(de_nom)

                    sk_sq_weighted_change.add_(weighted_sk_p1_masked.sum() - weighted_sk_masked._values().sum())

                    weighted_sk_p1_delta_masked = weighted_sk_p1_masked - weighted_sk_masked._values()
                    weighted_sk_p1_delta = torch.sparse_coo_tensor(
                        grad.indices(), weighted_sk_p1_delta_masked, grad.shape
                    )
                    weighted_sk.add_(weighted_sk_p1_delta)

                    sk_l1_masked = sk_masked._values().abs().sum()
                    sk_l1_change.add_(sk_l1_masked - old_sk_l1_masked)
                else:
                    self.apply_weight_decay(
                        p=p,
                        grad=grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                    old_sk_sq_weighted_param = sk.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    old_sk_l1_param = sk.abs().sum()

                    alpha_k.add_(grad.pow(2))
                    grad_sq = grad.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    g_sq.add_(grad_sq)

                    sk.add_(grad, alpha=d_lr)

                    sk_sq_weighted_param = sk.pow(2).div(torch.sqrt(alpha_k) + eps).sum()
                    sk_l1_param = sk.abs().sum()

                    sk_sq_weighted_change.add_(sk_sq_weighted_param - old_sk_sq_weighted_param)
                    sk_l1_change.add_(sk_l1_param - old_sk_l1_param)

        sk_sq_weighted.add_(sk_sq_weighted_change)
        gsq_weighted.add_(g_sq, alpha=d_lr ** 2)  # fmt: skip
        sk_l1.add_(sk_l1_change)

        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            d_hat = (sk_sq_weighted - gsq_weighted) / sk_l1
            d = group['d'] = max(d, min(d_hat.item(), d * group['growth_rate']))

        for group in self.param_groups:
            group['gsq_weighted'] = gsq_weighted
            group['sk_sq_weighted'] = sk_sq_weighted
            group['sk_l1'] = sk_l1
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]

                alpha_k, sk, x0 = state['alpha_k'], state['sk'], state['x0']

                if grad.is_sparse:
                    grad = grad.coalesce()

                    sk_masked = sk.sparse_mask(grad).coalesce()._values()
                    alpha_k_masked = alpha_k.sparse_mask(grad).coalesce()._values()
                    x0_masked = x0.sparse_mask(grad).coalesce()._values()
                    p_masked = p.sparse_mask(grad).coalesce()._values()

                    loc_masked = x0_masked - sk_masked.div(torch.sqrt(alpha_k_masked + group['eps']))

                    loc_delta_masked = loc_masked - p_masked
                    loc_delta = torch.sparse_coo_tensor(grad.indices(), loc_delta_masked, grad.shape)
                    p.add_(loc_delta)
                else:
                    z = x0 - sk.div(alpha_k.sqrt().add_(group['eps']))

                    if group['momentum'] > 0.0:
                        p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])
                    else:
                        p.copy_(z)

            group['k'] += 1

        return loss
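
A minimal usage sketch for DAdaptAdaGrad (assuming the class is exported at the package top level, as the other optimizers in this documentation are); the learning rate is left at 1.0 and D-Adaptation picks the effective step size:

import torch
from pytorch_optimizer import DAdaptAdaGrad

model = torch.nn.Linear(10, 1)
optimizer = DAdaptAdaGrad(model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0)

x, y = torch.randn(16, 10), torch.randn(16, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()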

DAdaptAdam

Bases: Optimizer, BaseOptimizer

Adam with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. betas.

(0.9, 0.999)
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. use AdamW style weight decay.

False
fixed_decay bool

bool. fix weight decay.

False
bias_correction bool

bool. Turn on Adam's bias correction.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/dadapt.py
class DAdaptAdam(Optimizer, BaseOptimizer):
    r"""Adam with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. betas.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use AdamW style weight decay.
    :param fixed_decay: bool. fix weight decay.
    :param bias_correction: bool. Turn on Adam's bias correction.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        bias_correction: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'd': d0,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'bias_correction': bias_correction,
            'step': 0,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdam'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        device = group['params'][0].device

        beta1, beta2 = group['betas']

        beta2_sq: float = math.sqrt(beta2)

        d: float = group['d']
        lr: float = group['lr']

        bias_correction1: float = 1.0 - beta1 ** (group['step'] + 1)
        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] + 1))
        bias_correction: float = bias_correction1 / bias_correction2_sq

        # it's not Adam Debias
        d_lr: float = self.apply_adam_debias(
            not group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
        )

        sk_l1 = torch.tensor([0.0], device=device)
        numerator_acc = torch.tensor([0.0], device=device)

        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if 'step' not in state:
                    state['s'] = torch.zeros_like(p)
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq, s = state['exp_avg'], state['exp_avg_sq'], state['s']

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                numerator_acc.add_(torch.dot(grad.flatten(), s.div(de_nom).flatten()), alpha=d_lr)

                exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                s.mul_(beta2_sq).add_(grad, alpha=d_lr * (1.0 - beta2_sq))

                sk_l1.add_(s.abs().sum())

        if sk_l1 == 0:
            return loss

        numerator_weighted.mul_(beta2_sq).add_(numerator_acc, alpha=1.0 - beta2_sq)  # fmt: skip

        if lr > 0.0:
            d_hat = numerator_weighted / ((1.0 - beta2_sq) * sk_l1)
            d = max(d, min(d_hat.item(), d * group['growth_rate']))

        for group in self.param_groups:
            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.addcdiv_(exp_avg, de_nom, value=-1.0)

        return loss
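
A minimal usage sketch for DAdaptAdam (hypothetical model and data; the top-level import is assumed). Passing a closure to step() is supported, and bias_correction enables Adam-style bias correction of the adapted step size:

import torch
from pytorch_optimizer import DAdaptAdam

model = torch.nn.Linear(10, 2)
optimizer = DAdaptAdam(model.parameters(), lr=1.0, betas=(0.9, 0.999), bias_correction=True)

def closure():
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    return loss

optimizer.step(closure)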

DAdaptSGD

Bases: Optimizer, BaseOptimizer

SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
momentum float

float. momentum.

0.9
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
class DAdaptSGD(Optimizer, BaseOptimizer):
    r"""SGD with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        momentum: float = 0.9,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'd': d0,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'step': 0,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptSGD'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['z'] = p.clone()
                state['s'] = torch.zeros_like(p)
                state['x0'] = p.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        device = group['params'][0].device

        sk_sq = torch.tensor([0.0], device=device)
        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        if group['step'] == 0:
            group['g0_norm'] = get_global_gradient_norm(self.param_groups, device).sqrt_().item()
        g0_norm = group['g0_norm']

        if g0_norm == 0:
            return loss

        d, lr = group['d'], group['lr']
        d_lr: float = d * lr / g0_norm

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['z'] = p.clone()
                    state['s'] = torch.zeros_like(p)
                    state['x0'] = p.clone()

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s = state['s']
                numerator_weighted.add_(torch.dot(grad.flatten(), s.flatten()), alpha=d_lr)

                s.add_(grad, alpha=d_lr)
                sk_sq.add_(s.pow(2).sum())

        if lr > 0.0:
            d_hat = 2.0 * numerator_weighted / sk_sq.sqrt()
            d = max(d, min(d_hat.item(), d * group['growth_rate']))

        for group in self.param_groups:
            group['step'] += 1

            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                z = state['z']
                z.copy_(state['x0'] - state['s'])

                p.mul_(group['momentum']).add_(z, alpha=1.0 - group['momentum'])

        return loss
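
A minimal usage sketch for DAdaptSGD (hypothetical model; top-level import assumed), keeping lr at 1.0 and using decoupled weight decay:

import torch
from pytorch_optimizer import DAdaptSGD

model = torch.nn.Linear(10, 1)
optimizer = DAdaptSGD(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-4, weight_decouple=True)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()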

DAdaptAdan

Bases: Optimizer, BaseOptimizer

Adan with D-Adaptation. Leave LR set to 1 unless you encounter instability.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.98, 0.92, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. decoupled weight decay.

False
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate. Default is inf, for unrestricted.

float('inf')
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/dadapt.py
class DAdaptAdan(Optimizer, BaseOptimizer):
    r"""Adan with D-Adaptation. Leave LR set to 1 unless you encounter instability.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. decoupled weight decay.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
        Default is inf, for unrestricted.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.98, 0.92, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        d0: float = 1e-6,
        growth_rate: float = float('inf'),
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'd': d0,
            'growth_rate': growth_rate,
            'k': 0,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptAdan'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['step'] = 0
                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg_diff'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]

        beta1, beta2, beta3 = group['betas']
        growth_rate = group['growth_rate']

        d, lr = group['d'], group['lr']
        d_lr = float(d * lr)

        g_sq = torch.tensor([0.0], device=group['params'][0].device)
        sk_sq_weighted = torch.tensor([0.0], device=group['params'][0].device)
        sk_l1 = torch.tensor([0.0], device=group['params'][0].device)
        if 'gsq_weighted' not in group:
            group['gsq_weighted'] = torch.tensor([0.0], device=group['params'][0].device)
        gsq_weighted = group['gsq_weighted']

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if 'step' not in state:
                    state['step'] = 0

                    state['s'] = torch.zeros_like(p)
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)
                    state['previous_grad'] = -grad.clone()

                grad_diff = state['previous_grad']
                grad_diff.add_(grad)

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                exp_avg.mul_(beta1).add_(grad, alpha=d_lr * (1.0 - beta1))
                exp_avg_diff.mul_(beta2).add_(grad_diff, alpha=d_lr * (1.0 - beta2))

                grad_diff.mul_(beta2).add_(grad)
                grad_diff = to_real(grad_diff * grad_diff.conj())
                exp_avg_sq.mul_(beta3).addcmul_(grad_diff, grad_diff, value=1.0 - beta3)

                grad_power = to_real(grad * grad.conj())
                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                g_sq.add_(grad_power.div_(de_nom).sum())

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=d_lr * (1.0 - beta3))

                sk_sq_weighted.add_(to_real(s * s.conj()).div_(de_nom).sum())
                sk_l1.add_(s.abs().sum())

                state['previous_grad'].copy_(-grad)

        if sk_l1 == 0:
            return loss

        gsq_weighted.mul_(beta3).add_(g_sq, alpha=(d_lr ** 2) * (1.0 - beta3))  # fmt: skip

        if lr > 0.0:
            d_hat = (sk_sq_weighted / (1.0 - beta3) - gsq_weighted) / sk_l1
            d = max(d, min(d_hat, d * growth_rate))

        for group in self.param_groups:
            group['gsq_weighted'] = gsq_weighted
            group['d'] = d
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['step'] += 1

                exp_avg, exp_avg_sq, exp_avg_diff = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_diff']

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                if group['weight_decouple']:
                    p.mul_(1.0 - d_lr * group['weight_decay'])

                p.addcdiv_(exp_avg, de_nom, value=-1.0)
                p.addcdiv_(exp_avg_diff, de_nom, value=-beta2)

                if not group['weight_decouple']:
                    p.div_(1.0 + d_lr * group['weight_decay'])

            group['k'] += 1

        return loss
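
A minimal usage sketch for DAdaptAdan (hypothetical model; top-level import assumed), with the default three-beta schedule and decoupled weight decay:

import torch
from pytorch_optimizer import DAdaptAdan

model = torch.nn.Linear(10, 1)
optimizer = DAdaptAdan(model.parameters(), lr=1.0, betas=(0.98, 0.92, 0.99), weight_decay=0.02, weight_decouple=True)

for _ in range(3):
    optimizer.zero_grad()
    model(torch.randn(8, 10)).abs().mean().backward()
    optimizer.step()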

DAdaptLion

Bases: Optimizer, BaseOptimizer

Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
Source code in pytorch_optimizer/optimizer/dadapt.py
class DAdaptLion(Optimizer, BaseOptimizer):
    r"""Lion with D-Adaptation. Leave LR set to 1 unless you encounter instability. This implementation is based on V3.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        d0: float = 1e-6,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'd': d0,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'step': 0,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'DAdaptLion'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        device = group['params'][0].device

        if 'numerator_weighted' not in group:
            group['numerator_weighted'] = torch.tensor([0.0], device=device)
        numerator_weighted = group['numerator_weighted']

        sk_l1 = torch.tensor([0.0], device=device)
        numerator_accumulator = torch.tensor([0.0], device=device)

        beta1, beta2 = group['betas']
        beta2_sq = math.sqrt(beta2)

        d, lr = group['d'], group['lr']
        d_lr: float = d * lr

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, s = state['exp_avg'], state['s']

                update = exp_avg.clone().mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
                p.add_(update, alpha=-d_lr)

                exp_avg.mul_(beta2).add_(grad, alpha=(1.0 - beta2) * d_lr)

                numerator_accumulator.add_(torch.dot(update.flatten(), s.flatten()), alpha=d_lr)
                s.mul_(beta2_sq).add_(update, alpha=(1.0 - beta2_sq) * d_lr)

                sk_l1.add_(s.abs().sum())

        numerator_weighted.mul_(beta2_sq).add_(numerator_accumulator, alpha=1.0 - beta2_sq)

        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            d_hat: float = (numerator_weighted / ((1.0 - beta2_sq) * sk_l1)).item()
            d = max(d, d_hat)

        for group in self.param_groups:
            group['step'] += 1

            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

        return loss
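
A minimal usage sketch for DAdaptLion (hypothetical model; top-level import assumed):

import torch
from pytorch_optimizer import DAdaptLion

model = torch.nn.Linear(10, 1)
optimizer = DAdaptLion(model.parameters(), lr=1.0, betas=(0.9, 0.999), weight_decay=0.01, weight_decouple=True)

model(torch.randn(8, 10)).sum().backward()
optimizer.step()
optimizer.zero_grad()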

DiffGrad

Bases: Optimizer, BaseOptimizer

An Optimization Method for Convolutional Neural Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
rectify bool

bool. perform the rectified update similar to RAdam.

False
n_sma_threshold int

int. number of SMA threshold (recommended value: 5).

5
degenerated_to_sgd bool

bool. degenerated to SGD.

True
ams_bound bool

bool. whether to use the AMSBound variant.

False
r float

float. EMA factor. between 0.9 ~ 0.99 is preferred.

0.95
adanorm bool

bool. whether to use the AdaNorm variant.

False
adam_debias bool

bool. Only correct the denominator to avoid inflating step sizes early in training.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/diffgrad.py
class DiffGrad(Optimizer, BaseOptimizer):
    r"""An Optimization Method for Convolutional Neural Networks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param n_sma_threshold: int. (recommended is 5).
    :param degenerated_to_sgd: bool. degenerated to SGD.
    :param ams_bound: bool. whether to use the AMSBound variant.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = True,
        ams_bound: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'ams_bound': ams_bound,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'diffGrad'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['previous_grad'] = torch.zeros_like(p)
                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group['adam_debias'],
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['previous_grad'] = torch.zeros_like(p)
                    if group['ams_bound']:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                # compute diffGrad coefficient (dfc)
                dfc = state['previous_grad'].clone()
                dfc.sub_(grad).abs_().sigmoid_().mul_(exp_avg)
                state['previous_grad'].copy_(grad)

                if not group['rectify']:
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                    continue

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if n_sma >= self.n_sma_threshold:
                    p.addcdiv_(dfc, de_nom, value=-step_size)
                elif step_size > 0:
                    p.add_(exp_avg, alpha=-step_size)

        return loss
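
A minimal usage sketch for DiffGrad (hypothetical model and data; top-level import assumed). The friction coefficient computed from consecutive gradients is applied internally, so usage matches a standard Adam-style optimizer:

import torch
from pytorch_optimizer import DiffGrad

model = torch.nn.Linear(10, 1)
optimizer = DiffGrad(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

for _ in range(3):
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(torch.randn(16, 10)), torch.randn(16, 1)).backward()
    optimizer.step()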

DynamicLossScaler

Dynamically adjusts the loss scaling factor.

Dynamic loss scalers are important in mixed-precision training.
They help us avoid underflows and overflows in low-precision gradients.

See here for information:
<https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling>

Shamelessly stolen and adapted from FairSeq.
<https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py>

Reference : 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'

Parameters:

Name Type Description Default
init_scale float

Initial loss scale.

2.0 ** 15
scale_factor float

Factor by which to increase or decrease loss scale.

2.0
scale_window int

If we do not experience overflow in scale_window iterations, loss scale will increase by scale_factor.

2000
tolerance float

Pct of iterations that have overflowed after which we must decrease the loss scale.

0.0
threshold Optional[float]

If not None, the loss scale will not decrease below this threshold.

None
Source code in pytorch_optimizer/optimizer/fp16.py
class DynamicLossScaler:
    r"""Dynamically adjusts the loss scaling factor.

        Dynamic loss scalers are important in mixed-precision training.
        They help us avoid underflows and overflows in low-precision gradients.

        See here for information:
        <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#lossscaling>

        Shamelessly stolen and adapted from FairSeq.
        <https://github.com/pytorch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py>

        Reference : 'https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/fp16.py'

    :param init_scale: Initial loss scale.
    :param scale_factor: Factor by which to increase or decrease loss scale.
    :param scale_window: If we do not experience overflow in scale_window iterations,
        loss scale will increase by scale_factor.
    :param tolerance: Pct of iterations that have overflowed after which we must decrease the loss scale.
    :param threshold: If not None, the loss scale will not decrease below this threshold.
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 15,
        scale_factor: float = 2.0,
        scale_window: int = 2000,
        tolerance: float = 0.00,
        threshold: Optional[float] = None,
    ):  # fmt: skip
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold

        self.iter: int = 0
        self.last_overflow_iter: int = -1
        self.last_rescale_iter: int = -1
        self.overflows_since_rescale: int = 0
        self.has_overflow_serial: bool = False

    def update_scale(self, overflow: bool):
        r"""Update the loss scale.

            If overflow exceeds our tolerance, we decrease the loss scale.
            If the number of iterations since the last overflow exceeds the scale window, we increase the loss scale.

        :param overflow: bool. adjust scales to prevent overflow.
        """
        iter_since_rescale: int = self.iter - self.last_rescale_iter

        if overflow:
            # calculate how often we overflowed already
            self.last_overflow_iter = self.iter
            self.overflows_since_rescale += 1

            pct_overflow: float = self.overflows_since_rescale / float(iter_since_rescale)
            if pct_overflow >= self.tolerance:
                # decrease loss scale by the scale factor
                self.decrease_loss_scale()

                # reset iterations
                self.last_rescale_iter = self.iter
                self.overflows_since_rescale = 0
        elif (self.iter - self.last_overflow_iter) % self.scale_window == 0:
            # increase the loss scale by scale factor
            self.loss_scale *= self.scale_factor
            self.last_rescale_iter = self.iter

        self.iter += 1

    def decrease_loss_scale(self):
        r"""Decrease the loss scale by self.scale_factor.

        NOTE: the loss_scale will not go below `self.threshold`.
        """
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)
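
A short sketch (illustration only; importing from the module path shown above) of how the scaler reacts to an overflow and to an overflow-free window:

from pytorch_optimizer.optimizer.fp16 import DynamicLossScaler

scaler = DynamicLossScaler(init_scale=2.0 ** 15, scale_factor=2.0, scale_window=2000)

scaler.update_scale(overflow=True)   # overflow: the scale is divided by scale_factor
assert scaler.loss_scale == 2.0 ** 14

for _ in range(2000):                # a full overflow-free window: the scale is doubled again
    scaler.update_scale(overflow=False)
assert scaler.loss_scale == 2.0 ** 15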

decrease_loss_scale()

Decrease the loss scale by self.scale_factor.

NOTE: the loss_scale will not go below self.threshold.

Source code in pytorch_optimizer/optimizer/fp16.py
def decrease_loss_scale(self):
    r"""Decrease the loss scale by self.scale_factor.

    NOTE: the loss_scale will not go below `self.threshold`.
    """
    self.loss_scale /= self.scale_factor
    if self.threshold is not None:
        self.loss_scale = max(self.loss_scale, self.threshold)

update_scale(overflow)

Update the loss scale.

If overflow exceeds our tolerance, we decrease the loss scale.
If the number of iterations since the last overflow exceeds the scale window, we increase the loss scale.

Parameters:

Name Type Description Default
overflow bool

bool. adjust scales to prevent overflow.

required
Source code in pytorch_optimizer/optimizer/fp16.py
def update_scale(self, overflow: bool):
    r"""Update the loss scale.

        If overflow exceeds our tolerance, we decrease the loss scale.
        If the number of iterations since the last overflow exceeds the scale window, we increase the loss scale.

    :param overflow: bool. adjust scales to prevent overflow.
    """
    iter_since_rescale: int = self.iter - self.last_rescale_iter

    if overflow:
        # calculate how often we overflowed already
        self.last_overflow_iter = self.iter
        self.overflows_since_rescale += 1

        pct_overflow: float = self.overflows_since_rescale / float(iter_since_rescale)
        if pct_overflow >= self.tolerance:
            # decrease loss scale by the scale factor
            self.decrease_loss_scale()

            # reset iterations
            self.last_rescale_iter = self.iter
            self.overflows_since_rescale = 0
    elif (self.iter - self.last_overflow_iter) % self.scale_window == 0:
        # increase the loss scale by scale factor
        self.loss_scale *= self.scale_factor
        self.last_rescale_iter = self.iter

    self.iter += 1

SafeFP16Optimizer

Bases: Optimizer

Safe FP16 Optimizer.

Parameters:

Name Type Description Default
optimizer OPTIMIZER

OPTIMIZER.

required
aggregate_g_norms bool

bool. aggregate_g_norms.

False
min_loss_scale float

float. min_loss_scale.

2 ** -5
Source code in pytorch_optimizer/optimizer/fp16.py
class SafeFP16Optimizer(Optimizer):  # pragma: no cover
    r"""Safe FP16 Optimizer.

    :param optimizer: OPTIMIZER.
    :param aggregate_g_norms: bool. aggregate_g_norms.
    :param min_loss_scale: float. min_loss_scale.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        aggregate_g_norms: bool = False,
        min_loss_scale: float = 2 ** -5,
    ):  # fmt: skip
        self.optimizer = optimizer
        self.aggregate_g_norms = aggregate_g_norms
        self.min_loss_scale = min_loss_scale

        self.fp16_params = self.get_parameters(optimizer)
        self.fp32_params = self.build_fp32_params(self.fp16_params, flatten=False)

        # we want the optimizer to be tracking the fp32 parameters
        if len(optimizer.param_groups) != 1:
            # future implementers: this should hopefully be a matter of just
            # iterating through the param groups and keeping track of the pointer
            # through the fp32_params
            raise NotImplementedError('[-] Need to implement the parameter group transfer.')

        optimizer.param_groups[0]['params'] = self.fp32_params

        self.scaler: DynamicLossScaler = DynamicLossScaler(2.0 ** 15)  # fmt: skip
        self.needs_sync: bool = True

    @classmethod
    def get_parameters(cls, optimizer: OPTIMIZER):
        params: List = []
        for pg in optimizer.param_groups:
            params += list(pg['params'])
        return params

    @classmethod
    def build_fp32_params(
        cls, parameters: PARAMETERS, flatten: bool = True
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # create FP32 copy of parameters and grads
        if flatten:
            total_param_size: int = sum(p.numel() for p in parameters)
            fp32_params = torch.zeros(total_param_size, dtype=torch.float, device=parameters[0].device)

            offset: int = 0
            for p in parameters:
                p_num_el = p.numel()
                fp32_params[offset : offset + p_num_el].copy_(p.view(-1))
                offset += p_num_el

            fp32_params = nn.Parameter(fp32_params)
            fp32_params.grad = fp32_params.new(total_param_size)

            return fp32_params

        fp32_params: List = []
        for p in parameters:
            p32 = nn.Parameter(p.float())
            p32.grad = torch.zeros_like(p32)
            fp32_params.append(p32)

        return fp32_params

    def state_dict(self) -> Dict:
        r"""Return the optimizer state dict."""
        state_dict = self.optimizer.state_dict()
        if self.scaler is not None:
            state_dict['loss_scaler'] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict: Dict):
        r"""Load an optimizer state dict.

            In general, we should prefer the configuration of the existing optimizer instance
            (e.g., learning rate) over that found in the state_dict. This allows us to
            resume training from a checkpoint using a new set of optimizer args.

        :param state_dict: Dict. state_dict.
        """
        if 'loss_scaler' in state_dict and self.scaler is not None and isinstance(state_dict['loss_scaler'], float):
            self.scaler.loss_scale = state_dict['loss_scaler']
        self.optimizer.load_state_dict(state_dict)

    def backward(self, loss, update_main_grads: bool = False):
        r"""Compute the sum of gradients of the given tensor w.r.t. graph leaves.

            Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
            additionally dynamically scales the loss to avoid gradient underflow.

        :param loss: float. loss.
        :param update_main_grads: bool. update main gradient.
        """
        if self.scaler is not None:
            loss = loss * self.scaler.loss_scale

        loss.backward()

        self.needs_sync = True
        if update_main_grads:
            self.update_main_grads()

    def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
        r"""Sync fp16 to fp32 gradients."""
        if self.needs_sync:
            if self.scaler is not None:
                # correct for dynamic loss scaler
                multiply_grads /= self.scaler.loss_scale

            # copy FP16 grads to FP32
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue

                if p.grad is not None:
                    p32.grad.copy_(p.grad)
                    p32.grad.mul_(multiply_grads)
                else:
                    p32.grad = torch.zeros_like(p, dtype=torch.float)

            self.needs_sync = False

    def multiply_grads(self, c: float):
        r"""Multiply grads by a constant c."""
        if self.needs_sync:
            self.sync_fp16_grads_to_fp32(c)
        else:
            for p32 in self.fp32_params:
                p32.grad.mul_(c)

    def update_main_grads(self):
        self.sync_fp16_grads_to_fp32()

    def clip_main_grads(self, max_norm: float):
        r"""Clip gradient norm and updates dynamic loss scaler."""
        self.sync_fp16_grads_to_fp32()

        grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_g_norms)

        # detect overflow and adjust loss scale
        if self.scaler is not None:
            overflow: bool = has_overflow(grad_norm)
            prev_scale: float = self.scaler.loss_scale

            self.scaler.update_scale(overflow)

            if overflow:
                self.zero_grad()
                if self.scaler.loss_scale <= self.min_loss_scale:
                    # Use FloatingPointError as an uncommon error
                    # that parent functions can safely catch to stop training.
                    self.scaler.loss_scale = prev_scale

                    raise FloatingPointError(
                        f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
                        'Try lowering the learning rate, using gradient clipping or '
                        'increasing the batch size.\n'
                        f'Overflow: setting loss scale to {self.scaler.loss_scale}'
                    )

        return grad_norm

    def step(self, closure: CLOSURE = None):
        r"""Perform a single optimization step."""
        self.sync_fp16_grads_to_fp32()
        self.optimizer.step(closure)

        # copy FP32 params back into FP16 model
        for p, p32 in zip(self.fp16_params, self.fp32_params):
            if not p.requires_grad:
                continue
            p.data.copy_(p32)

    def zero_grad(self):
        r"""Clear the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        for p32 in self.fp32_params:
            p32.grad.zero_()
        self.needs_sync = False

    def get_lr(self) -> float:
        r"""Get learning rate."""
        return self.optimizer.get_lr()

    def set_lr(self, lr: float):
        r"""Set learning rate."""
        self.optimizer.set_lr(lr)

    @property
    def loss_scale(self) -> float:
        r"""Convenience function which TorchAgent calls to get current scale value."""
        return self.scaler.loss_scale

loss_scale: float property

Convenience function which TorchAgent calls to get current scale value.

backward(loss, update_main_grads=False)

Compute the sum of gradients of the given tensor w.r.t. graph leaves.

Compared to ``fairseq.optim.FairseqOptimizer.backward``, this function additionally scales the loss dynamically to avoid gradient underflow.

Parameters:

Name Type Description Default
loss

float. loss.

required
update_main_grads bool

bool. update main gradient.

False
Source code in pytorch_optimizer/optimizer/fp16.py
def backward(self, loss, update_main_grads: bool = False):
    r"""Compute the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this function
        additionally dynamically scales the loss to avoid gradient underflow.

    :param loss: float. loss.
    :param update_main_grads: bool. update main gradient.
    """
    if self.scaler is not None:
        loss = loss * self.scaler.loss_scale

    loss.backward()

    self.needs_sync = True
    if update_main_grads:
        self.update_main_grads()

clip_main_grads(max_norm)

Clip the gradient norm and update the dynamic loss scaler.

Source code in pytorch_optimizer/optimizer/fp16.py
def clip_main_grads(self, max_norm: float):
    r"""Clip gradient norm and updates dynamic loss scaler."""
    self.sync_fp16_grads_to_fp32()

    grad_norm = clip_grad_norm(self.fp32_params, max_norm, sync=self.aggregate_g_norms)

    # detect overflow and adjust loss scale
    if self.scaler is not None:
        overflow: bool = has_overflow(grad_norm)
        prev_scale: float = self.scaler.loss_scale

        self.scaler.update_scale(overflow)

        if overflow:
            self.zero_grad()
            if self.scaler.loss_scale <= self.min_loss_scale:
                # Use FloatingPointError as an uncommon error
                # that parent functions can safely catch to stop training.
                self.scaler.loss_scale = prev_scale

                raise FloatingPointError(
                    f'Minimum loss scale reached ({self.min_loss_scale}). Your loss is probably exploding. '
                    'Try lowering the learning rate, using gradient clipping or '
                    'increasing the batch size.\n'
                    f'Overflow: setting loss scale to {self.scaler.loss_scale}'
                )

    return grad_norm

get_lr()

Get learning rate.

Source code in pytorch_optimizer/optimizer/fp16.py
def get_lr(self) -> float:
    r"""Get learning rate."""
    return self.optimizer.get_lr()

load_state_dict(state_dict)

Load an optimizer state dict.

In general, we should prefer the configuration of the existing optimizer instance
(e.g., learning rate) over that found in the state_dict. This allows us to
resume training from a checkpoint using a new set of optimizer args.

Parameters:

Name Type Description Default
state_dict Dict

Dict. state_dict.

required
Source code in pytorch_optimizer/optimizer/fp16.py
def load_state_dict(self, state_dict: Dict):
    r"""Load an optimizer state dict.

        In general, we should prefer the configuration of the existing optimizer instance
        (e.g., learning rate) over that found in the state_dict. This allows us to
        resume training from a checkpoint using a new set of optimizer args.

    :param state_dict: Dict. state_dict.
    """
    if 'loss_scaler' in state_dict and self.scaler is not None and isinstance(state_dict['loss_scaler'], float):
        self.scaler.loss_scale = state_dict['loss_scaler']
    self.optimizer.load_state_dict(state_dict)

multiply_grads(c)

Multiply grads by a constant c.

Source code in pytorch_optimizer/optimizer/fp16.py
def multiply_grads(self, c: float):
    r"""Multiply grads by a constant c."""
    if self.needs_sync:
        self.sync_fp16_grads_to_fp32(c)
    else:
        for p32 in self.fp32_params:
            p32.grad.mul_(c)

set_lr(lr)

Set learning rate.

Source code in pytorch_optimizer/optimizer/fp16.py
def set_lr(self, lr: float):
    r"""Set learning rate."""
    self.optimizer.set_lr(lr)

state_dict()

Return the optimizer state dict.

Source code in pytorch_optimizer/optimizer/fp16.py
def state_dict(self) -> Dict:
    r"""Return the optimizer state dict."""
    state_dict = self.optimizer.state_dict()
    if self.scaler is not None:
        state_dict['loss_scaler'] = self.scaler.loss_scale
    return state_dict

step(closure=None)

Perform a single optimization step.

Source code in pytorch_optimizer/optimizer/fp16.py
def step(self, closure: CLOSURE = None):
    r"""Perform a single optimization step."""
    self.sync_fp16_grads_to_fp32()
    self.optimizer.step(closure)

    # copy FP32 params back into FP16 model
    for p, p32 in zip(self.fp16_params, self.fp32_params):
        if not p.requires_grad:
            continue
        p.data.copy_(p32)

sync_fp16_grads_to_fp32(multiply_grads=1.0)

Sync fp16 to fp32 gradients.

Source code in pytorch_optimizer/optimizer/fp16.py
def sync_fp16_grads_to_fp32(self, multiply_grads: float = 1.0):
    r"""Sync fp16 to fp32 gradients."""
    if self.needs_sync:
        if self.scaler is not None:
            # correct for dynamic loss scaler
            multiply_grads /= self.scaler.loss_scale

        # copy FP16 grads to FP32
        for p, p32 in zip(self.fp16_params, self.fp32_params):
            if not p.requires_grad:
                continue

            if p.grad is not None:
                p32.grad.copy_(p.grad)
                p32.grad.mul_(multiply_grads)
            else:
                p32.grad = torch.zeros_like(p, dtype=torch.float)

        self.needs_sync = False

zero_grad()

Clear the gradients of all optimized parameters.

Source code in pytorch_optimizer/optimizer/fp16.py
def zero_grad(self):
    r"""Clear the gradients of all optimized parameters."""
    for p in self.fp16_params:
        p.grad = None
    for p32 in self.fp32_params:
        p32.grad.zero_()
    self.needs_sync = False
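
Taken together, these methods define a fixed per-step protocol: scale and backpropagate the loss, sync and clip the fp32 master gradients, step the wrapped optimizer, and clear the gradients. A minimal sketch of that protocol follows; it assumes the wrapper documented here is exported as ``SafeFP16Optimizer`` and is constructed directly from an optimizer built over the half-precision parameters (the construction line is illustrative, check ``fp16.py`` for the exact signature):

    # Minimal sketch of one training step with the fp16 wrapper documented above.
    # Assumption: the wrapper is ``SafeFP16Optimizer`` and ``AdamP`` is an arbitrary
    # choice of wrapped optimizer; fp16 training normally runs on a GPU.
    import torch
    from torch import nn
    from pytorch_optimizer import AdamP, SafeFP16Optimizer

    model = nn.Linear(8, 2).half().cuda()                      # fp16 parameters
    optimizer = SafeFP16Optimizer(AdamP(model.parameters(), lr=1e-3))

    loss = model(torch.randn(4, 8, dtype=torch.float16, device='cuda')).float().mean()

    optimizer.backward(loss)                  # scales the loss, marks grads for syncing
    optimizer.clip_main_grads(max_norm=1.0)   # fp16 -> fp32 sync, clipping, loss-scale update
    optimizer.step()                          # fp32 update, copied back into the fp16 params
    optimizer.zero_grad()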

Fromage

Bases: Optimizer, BaseOptimizer

On the distance between two neural networks and the stability of learning.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
p_bound Optional[float]

Optional[float]. Restricts the optimisation to a bounded set. A value of 2.0 restricts parameter norms to lie within 2x their initial norms. This regularises the model class.

None
Source code in pytorch_optimizer/optimizer/fromage.py
class Fromage(Optimizer, BaseOptimizer):
    r"""On the distance between two neural networks and the stability of learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param p_bound: Optional[float]. Restricts the optimisation to a bounded set. A value of 2.0 restricts parameter
        norms to lie within 2x their initial norms. This regularises the model class.
    """

    def __init__(self, params: PARAMETERS, lr: float = 1e-2, p_bound: Optional[float] = None):
        self.validate_learning_rate(lr)

        self.p_bound = p_bound

        defaults: DEFAULTS = {'lr': lr}
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Fromage'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                if self.p_bound is not None:
                    state['max'] = p.norm().mul_(self.p_bound)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            pre_factor: float = math.sqrt(1 + group['lr'] ** 2)
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0 and self.p_bound is not None:
                    state['max'] = p.norm().mul_(self.p_bound)

                p_norm, g_norm = p.norm(), grad.norm()

                if p_norm > 0.0 and g_norm > 0.0:
                    p.add_(grad * (p_norm / g_norm), alpha=-group['lr'])
                else:
                    p.add_(grad, alpha=-group['lr'])

                p.div_(pre_factor)

                if self.p_bound is not None:
                    p_norm = p.norm()
                    if p_norm > state['max']:
                        p.mul_(state['max']).div_(p_norm)

        return loss
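
For reference, a minimal usage sketch: Fromage needs only a learning rate (and optionally ``p_bound``), since each update is scaled by the ratio of parameter norm to gradient norm and then shrunk by the ``sqrt(1 + lr^2)`` pre-factor shown above.

    import torch
    from torch import nn
    from pytorch_optimizer import Fromage

    model = nn.Linear(16, 4)
    optimizer = Fromage(model.parameters(), lr=1e-2, p_bound=2.0)

    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()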

GaLoreProjector

Memory-Efficient LLM Training by Gradient Low-Rank Projection.

Parameters:

Name Type Description Default
rank int

int. low rank to project.

128
update_proj_gap int

int. num steps to update the projection.

50
scale float

float. scale factor.

1.0
projection_type PROJECTION_TYPE

PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are supported.

'std'
Source code in pytorch_optimizer/optimizer/galore.py
class GaLoreProjector:
    r"""Memory-Efficient LLM Training by Gradient Low-Rank Projection.

    :param rank: int. low rank to project.
    :param update_proj_gap: int. num steps to update the projection.
    :param scale: float. scale factor.
    :param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are
        supported.
    """

    def __init__(
        self, rank: int = 128, update_proj_gap: int = 50, scale: float = 1.0, projection_type: PROJECTION_TYPE = 'std'
    ):
        self.rank = rank
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.projection_type = projection_type

        self.ortho_matrix: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None

    @staticmethod
    def get_orthogonal_matrix(
        weights: torch.Tensor, rank: int, projection_type: str
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if projection_type not in {'right', 'left', 'full'}:
            raise ValueError('projection_type should be one of left, right or full')

        original_type = weights.data.dtype
        original_device = weights.data.device
        is_float: bool = original_type == torch.float

        u, s, vh = torch.linalg.svd(weights if is_float else weights.float(), full_matrices=False)

        if projection_type == 'right':
            b = vh[:rank, :]
            return b if is_float else b.to(original_device).type(original_type)
        if projection_type == 'left':
            a = u[:, :rank]
            return a if is_float else a.to(original_device).type(original_type)

        a = u[:, :rank]
        b = vh[:rank, :]

        return (
            (a, b)
            if is_float
            else (a.to(original_device).type(original_type), b.to(original_device).type(original_type))
        )

    def get_low_rank_grad_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if grad.shape[0] >= grad.shape[1]:
            if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
            return torch.matmul(grad, self.ortho_matrix.t())

        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')

        return torch.matmul(self.ortho_matrix.t(), grad)

    def get_low_rank_grad_reverse_std(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if grad.shape[0] >= grad.shape[1]:
            if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
                self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
            return torch.matmul(self.ortho_matrix.t(), grad)

        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')

        return torch.matmul(grad, self.ortho_matrix.t())

    def get_low_rank_grad_right(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='right')
        return torch.matmul(grad, self.ortho_matrix.t())

    def get_low_rank_grad_left(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='left')
        return torch.matmul(self.ortho_matrix.t(), grad)

    def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
            self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full')
        return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t()

    def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
        if self.projection_type == 'std':
            return self.get_low_rank_grad_std(full_rank_grad, steps)
        if self.projection_type == 'reverse_std':
            return self.get_low_rank_grad_reverse_std(full_rank_grad, steps)
        if self.projection_type == 'right':
            return self.get_low_rank_grad_right(full_rank_grad, steps)
        if self.projection_type == 'left':
            return self.get_low_rank_grad_left(full_rank_grad, steps)
        if self.projection_type == 'full':
            return self.get_low_rank_grad_full(full_rank_grad, steps)
        raise NotImplementedError

    def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
        if self.projection_type == 'std':
            return (
                torch.matmul(low_rank_grad, self.ortho_matrix)
                if low_rank_grad.shape[0] >= low_rank_grad.shape[1]
                else torch.matmul(self.ortho_matrix, low_rank_grad)
            ) * self.scale
        if self.projection_type == 'reverse_std':
            return (
                torch.matmul(self.ortho_matrix, low_rank_grad.t())
                if low_rank_grad.shape[0] <= low_rank_grad.shape[1]
                else torch.matmul(low_rank_grad, self.ortho_matrix.t())
            ) * self.scale
        if self.projection_type == 'right':
            return torch.matmul(low_rank_grad, self.ortho_matrix.t()) * self.scale
        if self.projection_type == 'left':
            return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
        if self.projection_type == 'full':
            return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale

        raise NotImplementedError
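
A small round-trip sketch under the ``'std'`` projection: the full-rank gradient is projected into a rank-``rank`` space (the orthogonal matrix is recomputed via SVD on the first call and then every ``update_proj_gap`` steps), and ``project_back`` restores the original shape scaled by ``scale``. The import path below follows the source location given above.

    import torch
    from pytorch_optimizer.optimizer.galore import GaLoreProjector

    projector = GaLoreProjector(rank=4, update_proj_gap=50, scale=1.0, projection_type='std')

    grad = torch.randn(64, 16)                     # gradient of a 64x16 weight
    low_rank = projector.project(grad, steps=1)    # (64, 4): right projection since rows >= cols
    restored = projector.project_back(low_rank)    # (64, 16), multiplied by `scale`
    print(low_rank.shape, restored.shape)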

centralize_gradient(x, gc_conv_only=False)

Gradient Centralization (GC).

Parameters:

Name Type Description Default
x Tensor

torch.Tensor. gradient.

required
gc_conv_only bool

bool. 'False' for both conv & fc layers.

False
Source code in pytorch_optimizer/optimizer/gc.py
def centralize_gradient(x: torch.Tensor, gc_conv_only: bool = False):
    r"""Gradient Centralization (GC).

    :param x: torch.Tensor. gradient.
    :param gc_conv_only: bool. 'False' for both conv & fc layers.
    """
    size: int = x.dim()
    if (gc_conv_only and size > 3) or (not gc_conv_only and size > 1):
        x.add_(-x.mean(dim=tuple(range(1, size)), keepdim=True))
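
As a quick sketch, centralization subtracts the mean over all non-leading dimensions in place, so (for tensors with enough dimensions) each output-channel slice of the gradient ends up zero-mean. The import path follows the source location given above.

    import torch
    from pytorch_optimizer.optimizer.gc import centralize_gradient

    g = torch.randn(4, 3, 3, 3)                # e.g. a conv-weight gradient
    centralize_gradient(g, gc_conv_only=True)  # in-place; with gc_conv_only only dim > 3 tensors are touched
    print(g.mean(dim=(1, 2, 3)))               # approximately zero for every output channel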

Gravity

Bases: Optimizer, BaseOptimizer

a Kinematic Approach on Optimization in Deep Learning.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
alpha float

float. alpha controls the V initialization.

0.01
beta float

float. beta will be used to compute running average of V.

0.9
Source code in pytorch_optimizer/optimizer/gravity.py
class Gravity(Optimizer, BaseOptimizer):
    r"""a Kinematic Approach on Optimization in Deep Learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param alpha: float. alpha controls the V initialization.
    :param beta: float. beta will be used to compute running average of V.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        alpha: float = 0.01,
        beta: float = 0.9,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        defaults: DEFAULTS = {'lr': lr, 'alpha': alpha, 'beta': beta}
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Gravity'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['v'] = torch.empty_like(p).normal_(mean=0.0, std=group['alpha'] / group['lr'])

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta_t: float = (group['beta'] * group['step'] + 1) / (group['step'] + 2)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['v'] = torch.empty_like(p).normal_(mean=0.0, std=group['alpha'] / group['lr'])

                v = state['v']

                m = 1.0 / grad.abs().max()
                zeta = grad / (1.0 + (grad / m) ** 2)

                v.mul_(beta_t).add_(zeta, alpha=1.0 - beta_t)

                p.add_(v, alpha=-group['lr'])

        return loss
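
A minimal usage sketch: the velocity buffer ``v`` is initialized from a normal distribution with standard deviation ``alpha / lr`` and updated with the running average controlled by ``beta``, so the hyperparameters above are all that needs to be set.

    import torch
    from torch import nn
    from pytorch_optimizer import Gravity

    model = nn.Linear(16, 4)
    optimizer = Gravity(model.parameters(), lr=1e-2, alpha=0.01, beta=0.9)

    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()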

GSAM

Bases: Optimizer, BaseOptimizer

Surrogate Gap Guided Sharpness-Aware Minimization.

Example:

Here's an example:

    model = YourModel()
    base_optimizer = AdamP(model.parameters())
    lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps)
    rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr)
    optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler)

    def loss_fn(predictions, targets):
        return F.cross_entropy(predictions, targets)

    for inputs, targets in data:
        optimizer.set_closure(loss_fn, inputs, targets)
        predictions, loss = optimizer.step()
        lr_scheduler.step()
        optimizer.update_rho_t()

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer OPTIMIZER

Optimizer. base optimizer.

required
model Module

nn.Module. model.

required
alpha float

float. rho alpha.

0.4
rho_scheduler

rho scheduler.

required
adaptive bool

bool. element-wise Adaptive SAM.

False
perturb_eps float

float. epsilon for perturbation.

1e-12
kwargs

Dict. parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
class GSAM(Optimizer, BaseOptimizer):  # pragma: no cover
    r"""Surrogate Gap Guided Sharpness-Aware Minimization.

    Example:
    -------
        Here's an example::

            model = YourModel()
            base_optimizer = AdamP(model.parameters())
            lr_scheduler = LinearScheduler(base_optimizer, t_max=num_total_steps)
            rho_scheduler = ProportionScheduler(lr_scheduler, max_lr=max_lr)
            optimizer = GSAM(model.parameters(), base_optimizer, model, rho_scheduler)

            def loss_fn(predictions, targets):
                return F.cross_entropy(predictions, targets)

            for inputs, targets in data:
                optimizer.set_closure(loss_fn, inputs, targets)
                predictions, loss = optimizer.step()
                lr_scheduler.step()
                optimizer.update_rho_t()

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: Optimizer. base optimizer.
    :param model: nn.Module. model.
    :param alpha: float. rho alpha.
    :param rho_scheduler: rho scheduler.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param perturb_eps: float. epsilon for perturbation.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        params: PARAMETERS,
        base_optimizer: OPTIMIZER,
        model: nn.Module,
        rho_scheduler,
        alpha: float = 0.4,
        adaptive: bool = False,
        perturb_eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_range(alpha, 'alpha', 0.0, 1.0)

        self.model = model
        self.rho_scheduler = rho_scheduler
        self.alpha = alpha
        self.adaptive = adaptive
        self.perturb_eps = perturb_eps

        self.rho_t: float = 0.0
        self.forward_backward_func: Optional[Callable] = None

        if hasattr(ReduceOp, 'AVG'):
            self.grad_reduce = ReduceOp.AVG
            self.manual_average: bool = False
        else:  # PyTorch <= 1.11.0 does not have AVG, need to manually average across processes
            self.grad_reduce = ReduceOp.SUM
            self.manual_average: bool = True

        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups

        defaults: DEFAULTS = {'adaptive': adaptive}
        defaults.update(kwargs)
        super().__init__(params, defaults)

        self.update_rho_t()

    def __str__(self) -> str:
        return 'GSAM'

    @torch.no_grad()
    def reset(self):
        pass

    @torch.no_grad()
    def update_rho_t(self) -> float:
        self.rho_t = self.rho_scheduler.step()
        return self.rho_t

    @torch.no_grad()
    def perturb_weights(self, rho: float):
        grad_norm = self.grad_norm(weight_adaptive=self.adaptive)
        for group in self.param_groups:
            scale = rho / (grad_norm + self.perturb_eps)

            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['old_g'] = p.grad.clone()

                e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def un_perturb(self):
        for group in self.param_groups:
            for p in group['params']:
                if 'e_w' in self.state[p]:
                    p.sub_(self.state[p]['e_w'])

    @torch.no_grad()
    def gradient_decompose(self, alpha: float = 0.0):
        inner_prod = 0.0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                inner_prod += torch.sum(self.state[p]['old_g'] * p.grad)

        new_grad_norm = self.grad_norm(by=None)
        old_grad_norm = self.grad_norm(by='old_g')

        cosine = inner_prod / (new_grad_norm * old_grad_norm + self.perturb_eps)

        # gradient decomposition
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                vertical = self.state[p]['old_g'] - cosine * old_grad_norm * p.grad / (
                    new_grad_norm + self.perturb_eps
                )
                p.grad.add_(vertical, alpha=-alpha)

    @torch.no_grad()
    def sync_grad(self):
        if is_initialized():
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    all_reduce(p.grad, op=self.grad_reduce)
                    if self.manual_average:
                        p.grad.div_(float(get_world_size()))

    @torch.no_grad()
    def grad_norm(self, by: Optional[str] = None, weight_adaptive: bool = False) -> torch.Tensor:
        return torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if weight_adaptive else 1.0) * (p.grad if not by else self.state[p][by])).norm(p=2)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )

    def maybe_no_sync(self):
        return self.model.no_sync() if is_initialized() else ExitStack()

    @torch.no_grad()
    def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs):
        r"""Set closure.

            Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()`
            automatically performs forward and backward passes. This function does not take any arguments,
            and the inputs and targets data should be pre-set in the definition of partial-function.

        :param loss_fn: nn.Module. loss function.
        :param inputs: torch.Tensor. inputs.
        :param targets: torch.Tensor. targets.
        """

        def get_grad():
            self.base_optimizer.zero_grad()
            with torch.enable_grad():
                outputs = self.model(inputs)
                loss = loss_fn(outputs, targets, **kwargs)

            loss.backward()

            return outputs, loss.detach()

        self.forward_backward_func = get_grad

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> Tuple[torch.Tensor, float]:
        get_grad = closure if closure else self.forward_backward_func

        with self.maybe_no_sync():
            outputs, loss = get_grad()

            self.perturb_weights(rho=self.rho_t)

            disable_running_stats(self.model)

            get_grad()

            self.gradient_decompose(self.alpha)

            self.un_perturb()

        self.sync_grad()

        self.base_optimizer.step()

        enable_running_stats(self.model)

        return outputs, loss

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

set_closure(loss_fn, inputs, targets, **kwargs)

Set closure.

Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()`
automatically performs forward and backward passes. This function does not take any arguments,
and the inputs and targets data should be pre-set in the definition of partial-function.

Parameters:

Name Type Description Default
loss_fn Module

nn.Module. loss function.

required
inputs Tensor

torch.Tensor. inputs.

required
targets Tensor

torch.Tensor. targets.

required
Source code in pytorch_optimizer/optimizer/sam.py
@torch.no_grad()
def set_closure(self, loss_fn: nn.Module, inputs: torch.Tensor, targets: torch.Tensor, **kwargs):
    r"""Set closure.

        Create `self.forward_backward_func`, which is a function such that `self.forward_backward_func()`
        automatically performs forward and backward passes. This function does not take any arguments,
        and the inputs and targets data should be pre-set in the definition of partial-function.

    :param loss_fn: nn.Module. loss function.
    :param inputs: torch.Tensor. inputs.
    :param targets: torch.Tensor. targets.
    """

    def get_grad():
        self.base_optimizer.zero_grad()
        with torch.enable_grad():
            outputs = self.model(inputs)
            loss = loss_fn(outputs, targets, **kwargs)

        loss.backward()

        return outputs, loss.detach()

    self.forward_backward_func = get_grad

Lamb

Bases: Optimizer, BaseOptimizer

Large Batch Optimization for Deep Learning.

This Lamb implementation is based on the paper v3, which does not use de-biasing.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
rectify bool

bool. perform the rectified update similar to RAdam.

False
degenerated_to_sgd bool

bool. degenerated to SGD.

False
n_sma_threshold int

int. (recommended is 5).

5
grad_averaging bool

bool. whether apply (1 - beta2) to gradient when calculating running averages of gradient.

True
max_grad_norm float

float. max gradient norm to clip.

1.0
r float

float. EMA factor. between 0.9 ~ 0.99 is preferred.

0.95
adanorm bool

bool. whether to use the AdaNorm variant.

False
adam_debias bool

bool. Only correct the denominator to avoid inflating step sizes early in training.

False
adam bool

bool. always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes.

False
pre_norm bool

bool. perform pre-normalization of all gradients.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-06
Source code in pytorch_optimizer/optimizer/lamb.py
class Lamb(Optimizer, BaseOptimizer):
    r"""Large Batch Optimization for Deep Learning.

        This Lamb implementation is based on the paper v3, which does not use de-biasing.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param rectify: bool. perform the rectified update similar to RAdam.
    :param degenerated_to_sgd: bool. degenerated to SGD.
    :param n_sma_threshold: int. (recommended is 5).
    :param grad_averaging: bool. whether apply (1 - beta2) to gradient when calculating running averages of gradient.
    :param max_grad_norm: float. max gradient norm to clip.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param adam: bool. always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes.
    :param pre_norm: bool. perform pre-normalization of all gradients.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    clamp: float = 10.0

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = False,
        degenerated_to_sgd: bool = False,
        n_sma_threshold: int = 5,
        grad_averaging: bool = True,
        max_grad_norm: float = 1.0,
        adam: bool = False,
        pre_norm: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-6,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(max_grad_norm, 'max_grad_norm')
        self.validate_non_negative(eps, 'eps')

        self.degenerated_to_sgd = degenerated_to_sgd
        self.n_sma_threshold = n_sma_threshold
        self.pre_norm = pre_norm

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'rectify': rectify,
            'grad_averaging': grad_averaging,
            'max_grad_norm': max_grad_norm,
            'adam': adam,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Lamb'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
        if self.defaults['max_grad_norm'] == 0.0:
            return 1.0

        global_grad_norm = get_global_gradient_norm(self.param_groups, self.param_groups[0]['params'][0].device)
        global_grad_norm.sqrt_().add_(self.defaults['eps'])

        return torch.clamp(self.defaults['max_grad_norm'] / global_grad_norm, max=1.0)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        grad_norm = 1.0
        if self.pre_norm:
            grad_norm = self.get_global_gradient_norm()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            beta3: float = 1.0 - beta1 if group['grad_averaging'] else 1.0
            bias_correction1: float = 1.0 - beta1 ** group['step']

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=group['rectify'],
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group['adam_debias'],
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if self.pre_norm:
                    grad.div_(grad_norm)

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=beta3)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['rectify']:
                    update = p.clone()
                    if n_sma >= self.n_sma_threshold:
                        de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                        update.addcdiv_(exp_avg, de_nom, value=-step_size)
                    else:
                        update.add_(exp_avg, alpha=-step_size)
                else:
                    update = exp_avg / exp_avg_sq.sqrt().add_(group['eps'])

                weight_norm = torch.linalg.norm(p).clamp_(min=0, max=self.clamp)
                p_norm = torch.linalg.norm(update)
                trust_ratio: float = 1.0 if weight_norm == 0 or p_norm == 0 else weight_norm / (p_norm + group['eps'])

                state['weight_norm'] = weight_norm
                state['adam_norm'] = p_norm
                state['trust_ratio'] = trust_ratio

                if group['adam']:
                    trust_ratio = 1.0

                if group['rectify']:
                    if n_sma >= self.n_sma_threshold:
                        p.addcdiv_(exp_avg, de_nom, value=-step_size * trust_ratio)
                    else:
                        p.add_(exp_avg, alpha=-step_size * trust_ratio)
                else:
                    p.add_(update, alpha=-step_size * trust_ratio)

        return loss
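
A minimal usage sketch: the layer-wise trust ratio, i.e. the (clamped) weight norm divided by the update norm, is computed per parameter inside ``step``, so Lamb is used like any other optimizer.

    import torch
    from torch import nn
    from pytorch_optimizer import Lamb

    model = nn.Linear(32, 8)
    optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=1e-2, max_grad_norm=1.0)

    optimizer.zero_grad()
    loss = model(torch.randn(64, 32)).pow(2).mean()
    loss.backward()
    optimizer.step()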

LARS

Bases: Optimizer, BaseOptimizer

Layer-wise Adaptive Rate Scaling (no rate scaling or weight decay for parameters <= 1D).

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
weight_decay float

float. weight decay (L2 penalty).

0.0
momentum float

float. momentum.

0.9
dampening float

float. dampening for momentum.

0.0
trust_coefficient float

float. trust_coefficient.

0.001
nesterov bool

bool. enables nesterov momentum.

False
Source code in pytorch_optimizer/optimizer/lars.py
class LARS(Optimizer, BaseOptimizer):
    r"""Layer-wise Adaptive Rate Scaling (no rate scaling or weight decay for parameters <= 1D).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param momentum: float. momentum.
    :param dampening: float. dampening for momentum.
    :param trust_coefficient: float. trust_coefficient.
    :param nesterov: bool. enables nesterov momentum.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.9,
        dampening: float = 0.0,
        trust_coefficient: float = 1e-3,
        nesterov: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(dampening, 'dampening', 0.0, 1.0)
        self.validate_non_negative(trust_coefficient, 'trust_coefficient')

        defaults: DEFAULTS = {
            'lr': lr,
            'weight_decay': weight_decay,
            'momentum': momentum,
            'dampening': dampening,
            'trust_coefficient': trust_coefficient,
            'nesterov': nesterov,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Lars'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['mu'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    param_norm = torch.linalg.norm(p)
                    update_norm = torch.linalg.norm(grad)

                    one = torch.ones_like(param_norm)

                    trust_ratio = torch.where(
                        param_norm > 0.0,
                        torch.where(update_norm > 0.0, (group['trust_coefficient'] * param_norm / update_norm), one),
                        one,
                    )

                    grad.add_(p, alpha=group['weight_decay'])
                    grad.mul_(trust_ratio)

                if group['momentum'] > 0.0:
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = grad.clone().detach()

                    mb = state['momentum_buffer']
                    mb.mul_(group['momentum']).add_(grad, alpha=1.0 - group['dampening'])

                    if group['nesterov']:
                        grad.add_(mb, alpha=group['momentum'])
                    else:
                        grad.copy_(mb)

                p.add_(grad, alpha=-group['lr'])

        return loss
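
A minimal usage sketch: per the class description, rate scaling and weight decay are applied only to parameters with more than one dimension, so biases and normalization parameters fall back to plain momentum SGD.

    import torch
    from torch import nn
    from pytorch_optimizer import LARS

    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))
    optimizer = LARS(model.parameters(), lr=1e-3, weight_decay=1e-4, momentum=0.9)

    optimizer.zero_grad()
    loss = model(torch.randn(64, 32)).pow(2).mean()
    loss.backward()
    optimizer.step()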

Lion

Bases: Optimizer, BaseOptimizer

Symbolic Discovery of Optimization Algorithms.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.0001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.99)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
use_gc bool

bool. use gradient centralization.

False
r float

float. EMA factor. between 0.9 ~ 0.99 is preferred.

0.95
adanorm bool

bool. whether to use the AdaNorm variant.

False
Source code in pytorch_optimizer/optimizer/lion.py
class Lion(Optimizer, BaseOptimizer):
    r"""Symbolic Discovery of Optimization Algorithms.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param use_gc: bool. use gradient centralization.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        betas: BETAS = (0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        use_gc: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')

        self.use_gc = use_gc

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'adanorm': adanorm,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Lion'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                if self.use_gc:
                    centralize_gradient(grad, gc_conv_only=False)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg = state['exp_avg']
                update = exp_avg.clone()

                update.mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
                exp_avg.mul_(beta2).add_(s_grad, alpha=1.0 - beta2)

                p.add_(update, alpha=-group['lr'])

        return loss
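
A minimal usage sketch: as the update above shows, the applied step is the sign of an interpolation between the current gradient and the momentum buffer, so essentially every coordinate moves by ``lr`` in magnitude (before decoupled weight decay), which is why Lion is typically paired with a smaller learning rate than AdamW.

    import torch
    from torch import nn
    from pytorch_optimizer import Lion

    model = nn.Linear(16, 4)
    optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()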

LOMO

Bases: BaseOptimizer, Optimizer

Full Parameter Fine-tuning for Large Language Models with Limited Resources.

Reference: https://github.com/OpenLMLab/LOMO/blob/main/src/lomo.py
Check the usage from here: https://github.com/OpenLMLab/LOMO/blob/main/src/lomo_trainer.py

Parameters:

Name Type Description Default
model Module

nn.Module. pytorch model.

required
lr float

float. learning rate.

0.001
clip_grad_norm Optional[float]

Optional[float]. clip grad norm.

None
clip_grad_value Optional[float]

Optional[float]. clip grad value.

None
Source code in pytorch_optimizer/optimizer/lomo.py
class LOMO(BaseOptimizer, Optimizer):
    r"""Full Parameter Fine-tuning for Large Language Models with Limited Resources.

    Reference : https://github.com/OpenLMLab/LOMO/blob/main/src/lomo.py
    Check the usage from here : https://github.com/OpenLMLab/LOMO/blob/main/src/lomo_trainer.py

    :param model: nn.Module. pytorch model.
    :param lr: float. learning rate.
    :param clip_grad_norm: Optional[float]. clip grad norm.
    :param clip_grad_value: Optional[float]. clip grad value.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-3,
        clip_grad_norm: Optional[float] = None,
        clip_grad_value: Optional[float] = None,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(clip_grad_norm, 'clip_grad_norm')
        self.validate_non_negative(clip_grad_value, 'clip_grad_value')

        self.model = model
        self.lr = lr
        self.clip_grad_norm = clip_grad_norm
        self.clip_grad_value = clip_grad_value

        self.local_rank: int = int(os.environ.get('LOCAL_RANK', 0))

        self.gather_norm: bool = False
        self.grad_norms: List[torch.Tensor] = []
        self.clip_coef: Optional[float] = None

        p0: torch.Tensor = next(iter(self.model.parameters()))

        self.grad_func: Callable[[Any], Any] = (
            self.fuse_update_zero3() if hasattr(p0, 'ds_tensor') else self.fuse_update()
        )

        self.loss_scaler: Optional[DynamicLossScaler] = None
        if p0.dtype == torch.float16:
            if clip_grad_norm is None:
                raise ValueError(
                    '[-] Loss scaling is recommended to be used with grad norm to get better performance.'
                )

            self.loss_scaler = DynamicLossScaler(init_scale=2 ** 16)  # fmt: skip

        for _, p in self.model.named_parameters():
            if p.requires_grad:
                p.register_hook(self.grad_func)

        defaults: DEFAULTS = {'lr': lr}
        super().__init__(self.model.parameters(), defaults)

    def __str__(self) -> str:
        return 'LOMO'

    @torch.no_grad()
    def reset(self):
        pass

    def fuse_update(self) -> Callable[[Any], Any]:
        @torch.no_grad()
        def func(x: Any) -> Any:
            for _, p in self.model.named_parameters():
                if not p.requires_grad or p.grad is None:
                    continue

                if self.loss_scaler and self.loss_scaler.has_overflow_serial or has_overflow(p.grad):
                    p.grad = None
                    self.loss_scaler.has_overflow_serial = True
                    break

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                if self.loss_scaler:
                    grad_fp32.div_(self.loss_scaler.loss_scale)

                if self.gather_norm:
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                else:
                    if self.clip_grad_value is not None and self.clip_grad_value > 0.0:
                        grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)
                    if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is not None:
                        grad_fp32.mul_(self.clip_coef)

                    p_fp32 = p.to(torch.float32)
                    p_fp32.add_(grad_fp32, alpha=-self.lr)
                    p.copy_(p_fp32)

            return x

        return func

    def fuse_update_zero3(self) -> Callable[[Any], Any]:  # pragma: no cover
        @torch.no_grad()
        def func(x: torch.Tensor) -> torch.Tensor:
            for _, p in self.model.named_parameters():
                if p.grad is None:
                    continue

                all_reduce(p.grad, op=ReduceOp.AVG, async_op=False)

                if self.loss_scaler and self.loss_scaler.has_overflow_serial or has_overflow(p.grad):
                    p.grad = None
                    self.loss_scaler.has_overflow_serial = True
                    break

                grad_fp32 = p.grad.to(torch.float32)
                p.grad = None

                param_fp32 = p.ds_tensor.to(torch.float32)
                if self.loss_scaler:
                    grad_fp32.div_(self.loss_scaler.loss_scale)

                if self.gather_norm:
                    self.grad_norms.append(torch.norm(grad_fp32, 2.0))
                else:
                    one_dim_grad_fp32 = grad_fp32.view(-1)

                    partition_size: int = p.ds_tensor.numel()
                    start: int = partition_size * self.local_rank
                    end: int = min(start + partition_size, grad_fp32.numel())

                    partitioned_grad_fp32 = one_dim_grad_fp32.narrow(0, start, end - start)

                    if self.clip_grad_value is not None:
                        partitioned_grad_fp32.clamp_(min=-self.clip_grad_value, max=self.clip_grad_value)

                    if self.clip_grad_norm is not None and self.clip_grad_norm > 0 and self.clip_coef is not None:
                        partitioned_grad_fp32.mul_(self.clip_coef)

                    partitioned_p = param_fp32.narrow(0, 0, end - start)
                    partitioned_p.add_(partitioned_grad_fp32, alpha=-self.lr)

                    p.ds_tensor[: end - start] = partitioned_p  # fmt: skip

            return x

        return func

    def fused_backward(self, loss, lr: float):
        self.lr = lr

        if self.clip_grad_norm is not None and self.clip_grad_norm > 0.0 and self.clip_coef is None:
            raise ValueError(
                'clip_grad_norm is not None, but clip_coef is None. '
                'Please call optimizer.grad_norm() before optimizer.fused_backward().'
            )

        if self.loss_scaler:
            loss = loss * self.loss_scaler.loss_scale

        loss.backward()

        self.grad_func(0)

    def grad_norm(self, loss):
        self.gather_norm = True
        self.grad_norms = []

        if self.loss_scaler:
            self.loss_scaler.has_overflow_serial = False
            loss = loss * self.loss_scaler.loss_scale

        loss.backward(retain_graph=True)

        self.grad_func(0)

        if self.loss_scaler and self.loss_scaler.has_overflow_serial:
            self.loss_scaler.update_scale(overflow=True)

            with torch.no_grad():
                for _, p in self.model.named_parameters():
                    p.grad = None
            return

        with torch.no_grad():
            self.grad_norms = torch.stack(self.grad_norms)

            total_norm = torch.norm(self.grad_norms, 2.0)
            self.clip_coef = torch.clamp(float(self.clip_grad_norm) / (total_norm + 1e-6), max=1.0)

        self.gather_norm = False
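
A minimal end-to-end training sketch built on the methods above (the constructor signature, toy model, and loss are illustrative assumptions). Because clip_grad_norm is set, grad_norm() must be called before fused_backward(), which means two forward passes per batch:

import torch
from pytorch_optimizer import LOMO

model = torch.nn.Linear(16, 2)
optimizer = LOMO(model, lr=1e-3, clip_grad_norm=1.0)

for _ in range(10):
    x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))

    # first pass: gather per-parameter gradient norms to compute the clip coefficient
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.grad_norm(loss)

    # second pass: the fused backward updates the parameters inside the gradient hooks,
    # so there is no separate optimizer.step() call
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.fused_backward(loss, lr=1e-3)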

Lookahead

Bases: Optimizer, BaseOptimizer

k steps forward, 1 step back.

Parameters:

Name Type Description Default
optimizer OPTIMIZER

OPTIMIZER. base optimizer.

required
k int

int. number of lookahead steps.

5
alpha float

float. linear interpolation factor.

0.5
pullback_momentum str

str. how to handle the inner optimizer's momentum on the interpolation update: 'none', 'reset', or 'pullback'.

'none'
Source code in pytorch_optimizer/optimizer/lookahead.py
class Lookahead(Optimizer, BaseOptimizer):
    r"""k steps forward, 1 step back.

    :param optimizer: OPTIMIZER. base optimizer.
    :param k: int. number of lookahead steps.
    :param alpha: float. linear interpolation factor.
    :param pullback_momentum: str. change to inner optimizer momentum on interpolation update.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        k: int = 5,
        alpha: float = 0.5,
        pullback_momentum: str = 'none',
    ):
        self.validate_positive(k, 'k')
        self.validate_range(alpha, 'alpha', 0.0, 1.0)
        self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])

        self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
        self._optimizer_step_post_hooks: Dict[int, Callable] = {}

        self.alpha = alpha
        self.k = k
        self.pullback_momentum = pullback_momentum

        self.optimizer = optimizer
        self.param_groups = self.optimizer.param_groups

        self.state: STATE = defaultdict(dict)

        for group in self.param_groups:
            if 'counter' not in group:
                group['counter'] = 0

            for p in group['params']:
                state = self.state[p]
                state['slow_params'] = torch.empty_like(p)
                state['slow_params'].copy_(p)
                if self.pullback_momentum == 'pullback':
                    state['slow_momentum'] = torch.zeros_like(p)

        self.defaults: DEFAULTS = {
            'lookahead_alpha': alpha,
            'lookahead_k': k,
            'lookahead_pullback_momentum': pullback_momentum,
            **optimizer.defaults,
        }

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'k': self.k,
            'pullback_momentum': self.pullback_momentum,
        }

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['counter'] = 0

    def backup_and_load_cache(self):
        r"""Backup cache parameters."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['backup_params'] = torch.empty_like(p)
                state['backup_params'].copy_(p)
                p.data.copy_(state['slow_params'])

    def clear_and_load_backup(self):
        r"""Load backup parameters."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                p.data.copy_(state['backup_params'])
                del state['backup_params']

    def state_dict(self) -> STATE:
        return self.optimizer.state_dict()

    def load_state_dict(self, state: STATE):
        r"""Load state."""
        self.optimizer.load_state_dict(state)

    @torch.no_grad()
    def zero_grad(self):
        self.optimizer.zero_grad(set_to_none=True)

    @torch.no_grad()
    def update(self, group: Dict):
        for p in group['params']:
            if p.grad is None:
                continue

            state = self.state[p]

            slow = state['slow_params']

            p.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
            slow.copy_(p)

            if 'momentum_buffer' not in self.optimizer.state[p]:
                self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)

            if self.pullback_momentum == 'pullback':
                internal_momentum = self.optimizer.state[p]['momentum_buffer']
                self.optimizer.state[p]['momentum_buffer'] = internal_momentum.mul_(self.alpha).add_(
                    state['slow_momentum'], alpha=1.0 - self.alpha
                )
                state['slow_momentum'] = self.optimizer.state[p]['momentum_buffer']
            elif self.pullback_momentum == 'reset':
                self.optimizer.state[p]['momentum_buffer'] = torch.zeros_like(p)

    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = self.optimizer.step(closure)
        for group in self.param_groups:
            group['counter'] += 1
            if group['counter'] >= self.k:
                group['counter'] = 0
                self.update(group)
        return loss
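
A minimal usage sketch (the wrapped SGD settings and the toy model are illustrative):

import torch
from pytorch_optimizer import Lookahead

model = torch.nn.Linear(10, 1)
base_optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5, pullback_momentum='pullback')

for _ in range(20):
    x = torch.randn(32, 10)
    loss = model(x).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # every k-th step also pulls the fast weights toward the slow copy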

backup_and_load_cache()

Back up the current fast parameters and load the cached slow parameters into the model.

Source code in pytorch_optimizer/optimizer/lookahead.py
def backup_and_load_cache(self):
    r"""Backup cache parameters."""
    for group in self.param_groups:
        for p in group['params']:
            state = self.state[p]
            state['backup_params'] = torch.empty_like(p)
            state['backup_params'].copy_(p)
            p.data.copy_(state['slow_params'])

clear_and_load_backup()

Restore the backed-up fast parameters and remove the backup.

Source code in pytorch_optimizer/optimizer/lookahead.py
def clear_and_load_backup(self):
    r"""Load backup parameters."""
    for group in self.param_groups:
        for p in group['params']:
            state = self.state[p]
            p.data.copy_(state['backup_params'])
            del state['backup_params']
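
backup_and_load_cache() and clear_and_load_backup() are typically paired around evaluation, so that metrics are computed with the slow (averaged) weights and training then resumes from the fast weights; a hedged sketch with a toy model:

import torch
from pytorch_optimizer import Lookahead

model = torch.nn.Linear(10, 1)
optimizer = Lookahead(torch.optim.SGD(model.parameters(), lr=1e-2), k=5)

# ... training steps ...

optimizer.backup_and_load_cache()        # stash fast weights, load the slow weights
with torch.no_grad():
    val_out = model(torch.randn(4, 10))  # evaluate with the slow weights
optimizer.clear_and_load_backup()        # restore fast weights and drop the backup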

load_state_dict(state)

Load state.

Source code in pytorch_optimizer/optimizer/lookahead.py
def load_state_dict(self, state: STATE):
    r"""Load state."""
    self.optimizer.load_state_dict(state)

MADGRAD

Bases: Optimizer, BaseOptimizer

A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization (slightly modified).

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor.

0.9
eps float

float. term added to the denominator to improve numerical stability.

1e-06
weight_decay float

float. weight decay (L2 penalty). MADGRAD optimizer requires less weight decay than other methods, often as little as zero. On sparse problems both weight_decay and momentum should be set to 0.

0.0
weight_decouple bool

bool. Apply AdamW style decoupled weight decay.

False
Source code in pytorch_optimizer/optimizer/madgrad.py
class MADGRAD(Optimizer, BaseOptimizer):
    r"""A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param eps: float. term added to the denominator to improve numerical stability.
    :param weight_decay: float. weight decay (L2 penalty).
        MADGRAD optimizer requires less weight decay than other methods, often as little as zero.
        On sparse problems both weight_decay and momentum should be set to 0.
    :param weight_decouple: bool. Apply AdamW style decoupled weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        eps: float = 1e-6,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'momentum': momentum,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MADGRAD'

    @torch.no_grad()
    def reset(self):
        self.state['k'] = torch.tensor([0], dtype=torch.long, requires_grad=False)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['grad_sum_sq'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)
                if group['momentum'] > 0.0:
                    state['x0'] = torch.clone(p).detach()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # step counter must be stored in state to ensure correct behavior under optimizer sharding
        if 'k' not in self.state:
            self.state['k'] = torch.tensor([0], dtype=torch.long, requires_grad=False)

        for group in self.param_groups:
            weight_decay, momentum, eps = group['weight_decay'], group['momentum'], group['eps']
            lr = group['lr'] + eps

            _lambda = lr * math.pow(self.state['k'] + 1, 0.5)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                if 'grad_sum_sq' not in state:
                    state['grad_sum_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)
                    if momentum > 0.0:
                        state['x0'] = torch.clone(p).detach()

                if momentum > 0.0 and grad.is_sparse:
                    raise NoSparseGradientError(str(self), note='momentum > 0.0')

                grad_sum_sq, s = state['grad_sum_sq'], state['s']
                if weight_decay > 0.0 and not group['weight_decouple']:
                    if grad.is_sparse:
                        raise NoSparseGradientError(str(self), note='weight_decay')

                    # original implementation. not AdamW style
                    grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    rms_masked_values = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
                    x0_masked_values = p_masked._values().addcdiv(s_masked._values(), rms_masked_values, value=1)

                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=_lambda)
                    grad_sum_sq_masked.add_(grad_sq, alpha=_lambda)

                    rms_masked_values = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
                    if eps == 0.0:
                        rms_masked_values[rms_masked_values == 0] = float('inf')

                    s.add_(grad, alpha=_lambda)
                    s_masked._values().add_(grad._values(), alpha=_lambda)

                    # update masked copy of p
                    p_kp1_masked_values = x0_masked_values.addcdiv(s_masked._values(), rms_masked_values, value=-1)

                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_values, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0.0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state['x0']

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=_lambda)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    if eps == 0.0:
                        rms[rms == 0] = float('inf')

                    s.add_(grad, alpha=_lambda)

                    if weight_decay > 0.0 and group['weight_decouple']:
                        p_old = p.clone()

                    if momentum == 0.0:
                        p.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)
                        p.mul_(momentum).add_(z, alpha=1.0 - momentum)

                    if weight_decay > 0.0 and group['weight_decouple']:
                        p.add_(p_old, alpha=-lr * weight_decay)

        self.state['k'] += 1

        return loss
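
For reference, with momentum = 0, dense gradients, and no decoupled decay, the branch above reduces to dual averaging with a cube-root denominator; a standalone sketch of one step with toy tensors:

import torch

p = torch.randn(4)                     # parameter
grad = torch.randn(4)                  # current gradient
grad_sum_sq = torch.zeros(4)           # running sum of lambda_k * g_k^2
s = torch.zeros(4)                     # running sum of lambda_k * g_k
lr, eps, k = 1e-2, 1e-6, 0

lamb = (lr + eps) * (k + 1) ** 0.5     # lambda_k = (lr + eps) * sqrt(k + 1)

rms = grad_sum_sq.pow(1 / 3).add(eps)
x0 = p.addcdiv(s, rms, value=1)        # recover the initial point x_0

grad_sum_sq = grad_sum_sq.addcmul(grad, grad, value=lamb)
s = s.add(grad, alpha=lamb)

rms = grad_sum_sq.pow(1 / 3).add(eps)
p = x0.addcdiv(s, rms, value=-1)       # the momentum == 0 update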

MSVAG

Bases: Optimizer, BaseOptimizer

Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
beta float

float. Moving average (momentum) constant (scalar tensor or float value).

0.9
Source code in pytorch_optimizer/optimizer/msvag.py
class MSVAG(Optimizer, BaseOptimizer):
    r"""Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. Moving average (momentum) constant (scalar tensor or float value).
    """

    def __init__(self, params: PARAMETERS, lr: float = 1e-2, beta: float = 0.9):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        defaults: DEFAULTS = {'lr': lr, 'beta': beta}
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'MSVAG'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['s'] = torch.zeros_like(p)

    @staticmethod
    def get_rho(beta_power: float, beta: float) -> float:
        r"""Get rho."""
        rho: float = (1.0 - beta_power ** 2) * (1.0 - beta) ** 2  # fmt: skip
        rho /= (1.0 - beta) * (1.0 - beta_power) ** 2
        return min(rho, 0.9999)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta: float = group['beta']
            beta_power: float = beta ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['s'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta).add_(grad, alpha=1.0 - beta)
                exp_avg_sq.mul_(beta).addcmul_(grad, grad, value=1.0 - beta)

                m = exp_avg.div(beta_power)
                v = exp_avg_sq.div(beta_power)

                rho: float = self.get_rho(beta_power, beta)

                m_p2 = m.pow(2)
                s = (v - m_p2).div_(1.0 - rho)

                factor = m_p2.div(m_p2 + rho * s)
                torch.nan_to_num(factor, nan=0.0, out=factor)
                factor.clamp_(0.0, 1.0)

                p.add_(m * factor, alpha=-group['lr'])

        return loss
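
The step above scales the bias-corrected momentum element-wise by the variance-adaptation factor m^2 / (m^2 + rho * s); a tiny standalone sketch with toy values:

import torch

m = torch.tensor([1.0, 0.1])           # bias-corrected first moment
v = torch.tensor([1.5, 1.0])           # bias-corrected second moment
rho = 0.5                              # stand-in for get_rho(beta_power, beta)

s = (v - m.pow(2)) / (1.0 - rho)       # variance estimate
factor = m.pow(2) / (m.pow(2) + rho * s)
factor = factor.nan_to_num(0.0).clamp(0.0, 1.0)

# coordinates with noisy gradients get a factor near 0 (here ~0.01),
# coordinates with consistent gradients keep most of the step (here ~0.67)
print(factor)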

get_rho(beta_power, beta) staticmethod

Get rho.

Source code in pytorch_optimizer/optimizer/msvag.py
@staticmethod
def get_rho(beta_power: float, beta: float) -> float:
    r"""Get rho."""
    rho: float = (1.0 - beta_power ** 2) * (1.0 - beta) ** 2  # fmt: skip
    rho /= (1.0 - beta) * (1.0 - beta_power) ** 2
    return min(rho, 0.9999)

Nero

Bases: Optimizer, BaseOptimizer

Learning by Turning: Neural Architecture Aware Optimisation.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.01
beta float

float. coefficients used for computing running averages of gradient and the squared hessian trace.

0.999
constraints bool

bool. whether to apply the per-neuron constraints (re-centre to zero mean and re-normalise to unit norm after each step).

True
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/nero.py
class Nero(Optimizer, BaseOptimizer):
    """Learning by Turning: Neural Architecture Aware Optimisation.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param constraints: bool.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, constraints: bool = True, eps: float = 1e-8
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {'lr': lr, 'beta': beta, 'constraints': constraints, 'eps': eps}
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Nero'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                if group['constraints'] and p.dim() > 1:
                    p.sub_(neuron_mean(p))
                    p.div_(neuron_norm(p) + group['eps'])

                state = self.state[p]

                state['step'] = 0
                state['exp_avg_sq'] = torch.zeros_like(neuron_norm(p))
                state['scale'] = neuron_norm(p).mean()

                if state['scale'] == 0.0:
                    state['scale'] = 0.01

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    if group['constraints'] and p.dim() > 1:
                        p.sub_(neuron_mean(p))
                        p.div_(neuron_norm(p) + group['eps'])

                    state['step'] = 0
                    state['exp_avg_sq'] = torch.zeros_like(neuron_norm(p))
                    state['scale'] = neuron_norm(p).mean()
                    if state['scale'] == 0.0:
                        state['scale'] = 0.01

                state['step'] += 1

                grad_norm = neuron_norm(grad)

                exp_avg_sq = state['exp_avg_sq']
                exp_avg_sq.mul_(group['beta']).addcmul_(grad_norm, grad_norm, value=1.0 - group['beta'])

                bias_correction: float = 1.0 - group['beta'] ** state['step']

                grad_normed = grad / ((exp_avg_sq / bias_correction).sqrt_().add_(group['eps']))
                torch.nan_to_num(grad_normed, nan=0.0, out=grad_normed)

                p.sub_(grad_normed, alpha=group['lr'] * state['scale'])

                if group['constraints'] and p.dim() > 1:
                    p.sub_(neuron_mean(p))
                    p.div_(neuron_norm(p) + group['eps'])

        return loss
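
A standalone sketch of what the per-neuron constraint roughly does for a 2-D weight (the library's neuron_mean / neuron_norm helpers operate per output neuron; this toy version assumes one neuron per row):

import torch

w = torch.randn(32, 10)                         # one output neuron per row
w = w - w.mean(dim=1, keepdim=True)             # re-centre each neuron to zero mean
w = w / (w.norm(dim=1, keepdim=True) + 1e-8)    # re-normalise each neuron to unit norm

With constraints=True, this re-centring and re-normalisation is applied after every optimizer step.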

NovoGrad

Bases: Optimizer, BaseOptimizer

Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.95, 0.98)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
grad_averaging bool

bool. multiply the gradient by (1 - beta1), i.e. gradient averaging.

False
adam_debias bool

bool. Only correct the denominator to avoid inflating step sizes early in training.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/novograd.py
class NovoGrad(Optimizer, BaseOptimizer):
    r"""Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param grad_averaging: bool. multiply ck (1 - momentum).
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.95, 0.98),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        grad_averaging: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'grad_averaging': grad_averaging,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'NovoGrad'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                grad = p.grad

                g_2 = grad.pow(2).sum()  # fmt: skip

                state['moments'] = grad.div(g_2.sqrt() + group['eps']) + group['weight_decay'] * p
                state['grads_ema'] = g_2

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            step_size: float = group['lr'] * bias_correction2_sq
            if not group['adam_debias']:
                step_size /= bias_correction1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                grad_p2 = grad.pow(2).sum()

                if len(state) == 0:
                    state['moments'] = grad.div(grad_p2.sqrt() + group['eps']) + group['weight_decay'] * p
                    state['grads_ema'] = grad_p2

                grads_ema = state['grads_ema']
                grads_ema.mul_(beta2).add_(grad_p2, alpha=1.0 - beta2)

                de_nom = grads_ema.sqrt().add_(group['eps'])
                grad.div_(de_nom)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['grad_averaging']:
                    grad.mul_(1.0 - beta1)

                moments = state['moments']
                moments.mul_(beta1).add_(grad)

                p.add_(moments, alpha=-step_size)

        return loss
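
Unlike Adam, the second-moment estimate above (grads_ema) is a single scalar per parameter tensor rather than per element; a tiny standalone sketch of that layer-wise normalisation:

import torch

grad = torch.randn(32, 10)              # gradient of one layer
g_2 = grad.pow(2).sum()                 # one scalar second moment for the whole tensor
normalized = grad / (g_2.sqrt() + 1e-8)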

PAdam

Bases: Optimizer, BaseOptimizer

Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.1
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
partial float

float. partially adaptive parameter.

0.25
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/padam.py
class PAdam(Optimizer, BaseOptimizer):
    """Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param partial: float. partially adaptive parameter.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        betas: BETAS = (0.9, 0.999),
        partial: float = 0.25,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(partial, 'partial', 0.0, 1.0, range_type='(]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'partial': partial,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PAdam'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size: float = group['lr'] * bias_correction2_sq / bias_correction1

                p.addcdiv_(exp_avg, de_nom ** (group['partial'] * 2), value=-step_size)

        return loss
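
The partial parameter controls how adaptive the step is: the momentum is divided by (sqrt(exp_avg_sq) + eps) raised to 2 * partial, so partial = 0.5 recovers Adam-style scaling and small values approach SGD with momentum. A tiny numeric sketch of that denominator:

import torch

exp_avg_sq = torch.tensor([1e-4, 1.0, 100.0])   # per-coordinate second moments
de_nom = exp_avg_sq.sqrt().add(1e-8)

for partial in (0.0625, 0.25, 0.5):
    print(partial, de_nom ** (partial * 2))

# partial = 0.5 gives fully adaptive 1/sqrt(v) scaling, while smaller values
# flatten the per-coordinate differences toward plain momentum SGD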

PCGrad

Bases: BaseOptimizer

Gradient Surgery for Multi-Task Learning.

Parameters:

Name Type Description Default
optimizer OPTIMIZER

OPTIMIZER: optimizer instance.

required
reduction str

str. reduction method.

'mean'
Source code in pytorch_optimizer/optimizer/pcgrad.py
class PCGrad(BaseOptimizer):
    r"""Gradient Surgery for Multi-Task Learning.

    :param optimizer: OPTIMIZER: optimizer instance.
    :param reduction: str. reduction method.
    """

    def __init__(self, optimizer: OPTIMIZER, reduction: str = 'mean'):
        self.validate_options(reduction, 'reduction', ['mean', 'sum'])

        self.optimizer = optimizer
        self.reduction = reduction

    @torch.no_grad()
    def reset(self):
        self.zero_grad()

    def zero_grad(self):
        return self.optimizer.zero_grad(set_to_none=True)

    def step(self):
        return self.optimizer.step()

    def set_grad(self, grads: List[torch.Tensor]):
        idx: int = 0
        for group in self.optimizer.param_groups:
            for p in group['params']:
                p.grad = grads[idx]
                idx += 1

    def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
        r"""Get the gradient of the parameters of the network with specific objective."""
        grad, shape, has_grad = [], [], []
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p, device=p.device))
                    has_grad.append(torch.zeros_like(p, device=p.device))
                    continue

                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p, device=p.device))

        return grad, shape, has_grad

    def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
        r"""Pack the gradient of the parameters of the network for each objective.

        :param objectives: Iterable[nn.Module]. a list of objectives.
        :return: torch.Tensor. packed gradients.
        """
        grads, shapes, has_grads = [], [], []
        for objective in objectives:
            self.optimizer.zero_grad(set_to_none=True)
            objective.backward(retain_graph=True)

            grad, shape, has_grad = self.retrieve_grad()

            grads.append(flatten_grad(grad))
            has_grads.append(flatten_grad(has_grad))
            shapes.append(shape)

        return grads, shapes, has_grads

    def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
        r"""Project conflicting.

        :param grads: a list of the gradient of the parameters.
        :param has_grads: a list of mask represent whether the parameter has gradient.
        :return: torch.Tensor. merged gradients.
        """
        shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()

        pc_grad: List[torch.Tensor] = deepcopy(grads)
        for i, g_i in enumerate(pc_grad):
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j: torch.Tensor = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    pc_grad[i] -= g_i_g_j * g_j / (g_j.norm() ** 2)

        merged_grad: torch.Tensor = torch.zeros_like(grads[0])

        shared_pc_gradients: torch.Tensor = torch.stack([g[shared] for g in pc_grad])
        if self.reduction == 'mean':
            merged_grad[shared] = shared_pc_gradients.mean(dim=0)
        else:
            merged_grad[shared] = shared_pc_gradients.sum(dim=0)

        merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)

        return merged_grad

    def pc_backward(self, objectives: Iterable[nn.Module]):
        r"""Calculate the gradient of the parameters.

        :param objectives: Iterable[nn.Module]. a list of objectives.
        """
        grads, shapes, has_grads = self.pack_grad(objectives)

        pc_grad = self.project_conflicting(grads, has_grads)
        pc_grad = un_flatten_grad(pc_grad, shapes[0])

        self.set_grad(pc_grad)
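
A minimal two-task usage sketch (the toy model, losses, and AdamW settings are illustrative):

import torch
from pytorch_optimizer import PCGrad

model = torch.nn.Linear(10, 2)
optimizer = PCGrad(torch.optim.AdamW(model.parameters(), lr=1e-3))

x = torch.randn(32, 10)
out = model(x)
loss1 = out[:, 0].pow(2).mean()           # task 1
loss2 = (out[:, 1] - 1.0).pow(2).mean()   # task 2

optimizer.pc_backward([loss1, loss2])     # projects conflicting task gradients before merging
optimizer.step()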

pack_grad(objectives)

Pack the gradient of the parameters of the network for each objective.

Parameters:

Name Type Description Default
objectives Iterable

Iterable[nn.Module]. a list of objectives.

required

Returns:

Type Description
Tuple[List[Tensor], List[List[int]], List[Tensor]]

torch.Tensor. packed gradients.

Source code in pytorch_optimizer/optimizer/pcgrad.py
def pack_grad(self, objectives: Iterable) -> Tuple[List[torch.Tensor], List[List[int]], List[torch.Tensor]]:
    r"""Pack the gradient of the parameters of the network for each objective.

    :param objectives: Iterable[nn.Module]. a list of objectives.
    :return: torch.Tensor. packed gradients.
    """
    grads, shapes, has_grads = [], [], []
    for objective in objectives:
        self.optimizer.zero_grad(set_to_none=True)
        objective.backward(retain_graph=True)

        grad, shape, has_grad = self.retrieve_grad()

        grads.append(flatten_grad(grad))
        has_grads.append(flatten_grad(has_grad))
        shapes.append(shape)

    return grads, shapes, has_grads

pc_backward(objectives)

Calculate the gradient of the parameters.

Parameters:

Name Type Description Default
objectives Iterable[Module]

Iterable[nn.Module]. a list of objectives.

required
Source code in pytorch_optimizer/optimizer/pcgrad.py
def pc_backward(self, objectives: Iterable[nn.Module]):
    r"""Calculate the gradient of the parameters.

    :param objectives: Iterable[nn.Module]. a list of objectives.
    """
    grads, shapes, has_grads = self.pack_grad(objectives)

    pc_grad = self.project_conflicting(grads, has_grads)
    pc_grad = un_flatten_grad(pc_grad, shapes[0])

    self.set_grad(pc_grad)

project_conflicting(grads, has_grads)

Project conflicting.

Parameters:

Name Type Description Default
grads List[Tensor]

a list of gradients of the parameters.

required
has_grads List[Tensor]

a list of masks indicating whether each parameter has a gradient.

required

Returns:

Type Description
Tensor

torch.Tensor. merged gradients.

Source code in pytorch_optimizer/optimizer/pcgrad.py
def project_conflicting(self, grads: List[torch.Tensor], has_grads: List[torch.Tensor]) -> torch.Tensor:
    r"""Project conflicting.

    :param grads: a list of the gradient of the parameters.
    :param has_grads: a list of mask represent whether the parameter has gradient.
    :return: torch.Tensor. merged gradients.
    """
    shared: torch.Tensor = torch.stack(has_grads).prod(0).bool()

    pc_grad: List[torch.Tensor] = deepcopy(grads)
    for i, g_i in enumerate(pc_grad):
        random.shuffle(grads)
        for g_j in grads:
            g_i_g_j: torch.Tensor = torch.dot(g_i, g_j)
            if g_i_g_j < 0:
                pc_grad[i] -= g_i_g_j * g_j / (g_j.norm() ** 2)

    merged_grad: torch.Tensor = torch.zeros_like(grads[0])

    shared_pc_gradients: torch.Tensor = torch.stack([g[shared] for g in pc_grad])
    if self.reduction == 'mean':
        merged_grad[shared] = shared_pc_gradients.mean(dim=0)
    else:
        merged_grad[shared] = shared_pc_gradients.sum(dim=0)

    merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)

    return merged_grad
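
The projection itself is easy to see in isolation: when two task gradients conflict (negative dot product), the component of one along the other is subtracted, leaving them orthogonal. A small standalone sketch, independent of the class:

import torch

g1 = torch.tensor([1.0, 1.0])
g2 = torch.tensor([-1.0, 0.5])

dot = torch.dot(g1, g2)                   # -0.5, so the gradients conflict
if dot < 0:
    g1 = g1 - dot * g2 / g2.norm() ** 2

print(g1, torch.dot(g1, g2))              # projected g1 is orthogonal to g2 (dot product 0)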

retrieve_grad()

Get the gradients of the network's parameters for a specific objective.

Source code in pytorch_optimizer/optimizer/pcgrad.py
def retrieve_grad(self) -> Tuple[List[torch.Tensor], List[int], List[torch.Tensor]]:
    r"""Get the gradient of the parameters of the network with specific objective."""
    grad, shape, has_grad = [], [], []
    for group in self.optimizer.param_groups:
        for p in group['params']:
            if p.grad is None:
                shape.append(p.shape)
                grad.append(torch.zeros_like(p, device=p.device))
                has_grad.append(torch.zeros_like(p, device=p.device))
                continue

            shape.append(p.grad.shape)
            grad.append(p.grad.clone())
            has_grad.append(torch.ones_like(p, device=p.device))

    return grad, shape, has_grad

PID

Bases: Optimizer, BaseOptimizer

A PID Controller Approach for Stochastic Optimization of Deep Networks.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor.

0.0
dampening float

float. dampening for momentum.

0.0
derivative float

float. D part of the PID.

10.0
integral float

float. I part of the PID.

5.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
Source code in pytorch_optimizer/optimizer/pid.py
class PID(Optimizer, BaseOptimizer):
    r"""A PID Controller Approach for Stochastic Optimization of Deep Networks.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param dampening: float. dampening for momentum.
    :param derivative: float. D part of the PID.
    :param integral: float. I part of the PID.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        derivative: float = 10.0,
        integral: float = 5.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(derivative, 'derivative')
        self.validate_non_negative(integral, 'integral')
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'derivative': derivative,
            'integral': integral,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PID'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                if group['momentum'] > 0.0:
                    state['grad_buffer'] = torch.zeros_like(p)
                    state['i_buffer'] = torch.zeros_like(p)
                    state['d_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0 and group['momentum'] > 0.0:
                    state['grad_buffer'] = torch.zeros_like(p)
                    state['i_buffer'] = torch.zeros_like(p)
                    state['d_buffer'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['momentum'] > 0.0:
                    i_buf = state['i_buffer']
                    i_buf.mul_(group['momentum']).add_(grad, alpha=1.0 - group['dampening'])

                    g_buf, d_buf = state['grad_buffer'], state['d_buffer']
                    d_buf.mul_(group['momentum'])

                    if group['step'] > 1:
                        d_buf.add_(grad - g_buf, alpha=1.0 - group['momentum'])
                        g_buf.copy_(grad)

                    grad.add_(i_buf, alpha=group['integral']).add_(d_buf, alpha=group['derivative'])

                p.add_(grad, alpha=-group['lr'])

        return loss
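
The update direction above combines three terms in the spirit of a PID controller: the raw gradient (P), a momentum-style running sum (I), and a smoothed difference of consecutive gradients (D). A simplified, framework-free sketch of the buffers, with an illustrative momentum and the default integral / derivative gains:

import torch

momentum, dampening, ki, kd = 0.9, 0.0, 5.0, 10.0   # ki = integral, kd = derivative
i_buf, d_buf, prev_grad = torch.zeros(4), torch.zeros(4), torch.zeros(4)

for step in range(1, 4):
    grad = torch.randn(4)

    i_buf = i_buf * momentum + grad * (1.0 - dampening)        # I: accumulated gradient
    d_buf = d_buf * momentum
    if step > 1:
        d_buf = d_buf + (grad - prev_grad) * (1.0 - momentum)  # D: change in gradient
        prev_grad = grad.clone()

    update = grad + ki * i_buf + kd * d_buf                    # P + I + D, applied as -lr * update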

PNM

Bases: Optimizer, BaseOptimizer

Positive-Negative Momentum Optimizers.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 1.0)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. use decoupled (AdamW-style) weight decay.

True
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/pnm.py
class PNM(Optimizer, BaseOptimizer):
    r"""Positive-Negative Momentum Optimizers.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use weight_decouple.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'PNM'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['pos_momentum'] = torch.zeros_like(p)
                state['neg_momentum'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            noise_norm: float = math.sqrt((1 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['pos_momentum'] = torch.zeros_like(p)
                    state['neg_momentum'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['step'] % 2 == 1:
                    pos_momentum, neg_momentum = state['pos_momentum'], state['neg_momentum']
                else:
                    neg_momentum, pos_momentum = state['pos_momentum'], state['neg_momentum']

                pos_momentum.mul_(beta1 ** 2).add_(grad, alpha=1.0 - beta1 ** 2)  # fmt: skip

                delta_p = pos_momentum.mul(1 + beta2).add_(neg_momentum, alpha=-beta2).mul_(1.0 / noise_norm)
                p.add_(delta_p, alpha=-group['lr'])

        return loss
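
The two buffers above accumulate momentum on alternating steps and swap roles each step; the applied direction is (1 + beta2) * pos - beta2 * neg, normalised by sqrt((1 + beta2)^2 + beta2^2). A simplified, framework-free sketch:

import torch

beta1, beta2 = 0.9, 1.0
pos, neg = torch.zeros(4), torch.zeros(4)
noise_norm = ((1 + beta2) ** 2 + beta2 ** 2) ** 0.5

for step in range(1, 5):
    grad = torch.randn(4)

    if step % 2 == 0:                                        # the buffers swap roles each step
        pos, neg = neg, pos

    pos = pos * beta1 ** 2 + grad * (1.0 - beta1 ** 2)
    delta = ((1 + beta2) * pos - beta2 * neg) / noise_norm   # applied as -lr * delta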

Prodigy

Bases: Optimizer, BaseOptimizer

An Expeditiously Adaptive Parameter-Free Learner.

Leave LR set to 1 unless you encounter instability.
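
A hedged usage sketch following that advice (lr left at 1.0 so the d estimate discovers the effective step size; the toy model and loss are illustrative):

import torch
from pytorch_optimizer import Prodigy

model = torch.nn.Linear(10, 1)
optimizer = Prodigy(model.parameters(), lr=1.0, weight_decay=1e-2)

for _ in range(100):
    x = torch.randn(32, 10)
    loss = (model(x) - 1.0).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()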

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

1.0
betas BETAS

BETAS. betas.

(0.9, 0.999)
beta3 Optional[float]

float. coefficient for computing the Prodigy step-size using running averages. If set to None, uses the square root of beta2.

None
d0 float

float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.

1e-06
d_coef float

float. Coefficient in the expression for the estimate of d.

1.0
growth_rate float

float. prevent the D estimate from growing faster than this multiplicative rate.

float('inf')
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. use AdamW style weight decay.

True
fixed_decay bool

bool. fix weight decay.

False
bias_correction bool

bool. turn on Adam's bias correction.

False
safeguard_warmup bool

bool. remove lr from the denominator of D estimate to avoid issues during warm-up stage.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/prodigy.py
class Prodigy(Optimizer, BaseOptimizer):
    r"""An Expeditiously Adaptive Parameter-Free Learner.

        Leave LR set to 1 unless you encounter instability.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. betas.
    :param beta3: float. coefficient for computing the Prodigy step-size using running averages. If set to None,
        uses the value of square root of beta2.
    :param d0: float. initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
    :param d_coef: float. Coefficient in the expression for the estimate of d.
    :param growth_rate: float. prevent the D estimate from growing faster than this multiplicative rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. use AdamW style weight decay.
    :param fixed_decay: bool. fix weight decay.
    :param bias_correction: bool. turn on Adam's bias correction.
    :param safeguard_warmup: bool. remove lr from the denominator of D estimate to avoid issues during warm-up stage.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1.0,
        betas: BETAS = (0.9, 0.999),
        beta3: Optional[float] = None,
        d0: float = 1e-6,
        d_coef: float = 1.0,
        growth_rate: float = float('inf'),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        bias_correction: bool = False,
        safeguard_warmup: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas((*betas, beta3))
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'beta3': beta3,
            'd': d0,
            'd0': d0,
            'd_max': d0,
            'd_coef': d_coef,
            'growth_rate': growth_rate,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'bias_correction': bias_correction,
            'safeguard_warmup': safeguard_warmup,
            'step': 1,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Prodigy'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 1
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                state['s'] = torch.zeros_like(p)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        device = group['params'][0].device

        d_de_nom = torch.tensor([0.0], device=device)

        beta1, beta2 = group['betas']
        beta3 = group['beta3'] if group['beta3'] is not None else math.sqrt(beta2)

        bias_correction1: float = 1.0 - beta1 ** group['step']
        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
        bias_correction: float = (bias_correction1 / bias_correction2_sq) if group['bias_correction'] else 1.0

        d, d0 = group['d'], group['d0']
        d_lr: float = d * group['lr'] / bias_correction

        if 'd_numerator' not in group:
            group['d_numerator'] = torch.tensor([0.0], device=device)

        d_numerator = group['d_numerator']
        d_numerator.mul_(beta3)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['s'] = torch.zeros_like(p)
                    state['p0'] = p.clone()
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                p0, exp_avg, exp_avg_sq = state['p0'], state['exp_avg'], state['exp_avg_sq']

                d_numerator.add_(torch.dot(grad.flatten(), (p0 - p).flatten()), alpha=(d / d0) * d_lr)

                exp_avg.mul_(beta1).add_(grad, alpha=d * (1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1.0 - beta2))

                s = state['s']
                s.mul_(beta3).add_(grad, alpha=(d / d0) * (d if group['safeguard_warmup'] else d_lr))

                d_de_nom.add_(s.abs().sum())

        if d_de_nom == 0:
            return loss

        d_hat = (group['d_coef'] * d_numerator / d_de_nom).item()
        if d == group['d0']:
            d = max(d, d_hat)

        d_max = max(group['d_max'], d_hat)
        d = min(d_max, d * group['growth_rate'])

        for group in self.param_groups:
            group['step'] += 1

            group['d_numerator'] = d_numerator
            group['d_de_nom'] = d_de_nom
            group['d'] = d
            group['d_max'] = d_max
            group['d_hat'] = d_hat

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                de_nom = exp_avg_sq.sqrt().add_(d * group['eps'])

                self.apply_weight_decay(
                    p,
                    p.grad,
                    lr=d_lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.addcdiv_(exp_avg, de_nom, value=-d_lr)

        return loss
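
A minimal usage sketch for Prodigy (the model, data, and import path are illustrative, not part of the library). Because the optimizer estimates the step size itself through d, lr is left at its default of 1.0::

    import torch
    from pytorch_optimizer import Prodigy

    model = torch.nn.Linear(10, 1)           # placeholder model
    optimizer = Prodigy(model.parameters())  # lr stays at its default of 1.0

    for _ in range(100):
        x, y = torch.randn(32, 10), torch.randn(32, 1)  # synthetic batch
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()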

QHAdam

Bases: Optimizer, BaseOptimizer

Quasi-hyperbolic momentum and Adam for deep learning.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
nus Tuple[float, float]

Tuple[float, float]. immediate discount factors used to estimate the gradient and its square.

(1.0, 1.0)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/qhadam.py
class QHAdam(Optimizer, BaseOptimizer):
    r"""Quasi-hyperbolic momentum and Adam for deep learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param nus: Tuple[float, float]. immediate discount factors used to estimate the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        nus: Tuple[float, float] = (1.0, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_nus(nus)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'nus': nus,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'QHAdam'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['beta1_weight'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                state['beta2_weight'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']
            nu1, nu2 = group['nus']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['beta1_weight'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    state['beta2_weight'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                beta1_weight, beta2_weight = state['beta1_weight'], state['beta2_weight']
                beta1_weight.mul_(beta1).add_(1.0)
                beta2_weight.mul_(beta2).add_(1.0)

                beta1_adj = 1.0 - (1.0 / beta1_weight)
                beta2_adj = 1.0 - (1.0 / beta2_weight)

                grad_p2 = grad.pow(2)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1_adj).add_((1.0 - beta1_adj) * grad)
                exp_avg_sq.mul_(beta2_adj).add_((1.0 - beta2_adj) * grad_p2)

                avg_grad = exp_avg.mul(nu1)
                if nu1 != 1.0:
                    avg_grad.add_(grad, alpha=1.0 - nu1)

                avg_grad_rms = exp_avg_sq.mul(nu2)
                if nu2 != 1.0:
                    avg_grad_rms.add_(grad_p2, alpha=1.0 - nu2)

                avg_grad_rms.sqrt_().add_(group['eps'])

                p.addcdiv_(avg_grad, avg_grad_rms, value=-group['lr'])

        return loss
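
A small construction sketch for QHAdam (the model and the values below are examples, not library defaults). Setting nus=(1.0, 1.0) essentially recovers Adam; a commonly cited setting from the quasi-hyperbolic paper is nus=(0.7, 1.0) together with betas=(0.995, 0.999)::

    import torch
    from pytorch_optimizer import QHAdam

    model = torch.nn.Linear(10, 1)  # placeholder model
    # nu1 blends the momentum buffer with the raw gradient; nu2 does the same
    # for the second-moment estimate (see the avg_grad / avg_grad_rms lines above).
    optimizer = QHAdam(model.parameters(), lr=1e-3, betas=(0.995, 0.999), nus=(0.7, 1.0))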

QHM

Bases: Optimizer, BaseOptimizer

Quasi-hyperbolic momentum (QHM) optimization algorithm.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
momentum float

float. momentum factor.

0.0
nu float

float. immediate discount factor used to estimate the gradient and its square.

1.0
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

False
fixed_decay bool

bool. fix weight decay.

False
Source code in pytorch_optimizer/optimizer/qhm.py
class QHM(Optimizer, BaseOptimizer):
    r"""Quasi-hyperbolic momentum (QHM) optimization algorithm.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param nu: float. immediate discount factor used to estimate the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        nu: float = 1.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_nus(nu)

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'nu': nu,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'QHM'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                buf = state['momentum_buffer']
                buf.mul_(group['momentum']).add_(grad, alpha=1.0 - group['momentum'])

                p.add_(buf, alpha=-group['lr'] * group['nu'])
                p.add_(grad, alpha=-group['lr'] * (1.0 - group['nu']))

        return loss
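
Reading the two p.add_ calls above together, the QHM step is a nu-weighted blend of the momentum buffer and the raw gradient. The lines below simply restate the code in compact form (b is the momentum buffer, g the gradient; no additional library behaviour is implied)::

    b = momentum * b + (1.0 - momentum) * g
    p = p - lr * (nu * b + (1.0 - nu) * g)

    # nu = 1.0 gives (normalized) SGD with momentum; nu = 0.0 gives plain SGD.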

RAdam

Bases: Optimizer, BaseOptimizer

Rectified Adam.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
n_sma_threshold int

int. threshold for the number of SMA (simple moving average) steps; 5 is recommended.

5
degenerated_to_sgd bool

bool. perform an SGD update when the variance of the gradient is high.

False
r float

float. EMA factor for AdaNorm; values between 0.9 and 0.99 are preferred.

0.95
adanorm bool

bool. whether to use the AdaNorm variant.

False
adam_debias bool

bool. Only correct the denominator to avoid inflating step sizes early in training.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/radam.py
class RAdam(Optimizer, BaseOptimizer):
    r"""Rectified Adam.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param n_sma_threshold: int. (recommended is 5).
    :param degenerated_to_sgd: bool. perform an SGD update when the variance of the gradient is high.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'RAdam'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group['adam_debias'],
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                if step_size > 0 or n_sma >= self.n_sma_threshold:
                    self.apply_weight_decay(
                        p=p,
                        grad=None,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=group['weight_decouple'],
                        fixed_decay=group['fixed_decay'],
                    )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if n_sma >= self.n_sma_threshold:
                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                elif step_size > 0:
                    p.add_(exp_avg, alpha=-step_size)

        return loss
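
A short construction sketch for RAdam (placeholder model; values shown are the documented defaults). Early in training the rectification term is not yet reliable, so while n_sma stays below n_sma_threshold the code above either applies a momentum-only update or, with degenerated_to_sgd=False, skips the update for those steps::

    import torch
    from pytorch_optimizer import RAdam

    model = torch.nn.Linear(10, 1)  # placeholder model
    optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)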

Ranger

Bases: Optimizer, BaseOptimizer

A synergistic optimizer combining RAdam, Lookahead, and now gradient centralization (GC), in one optimizer.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.95, 0.999)
weight_decay float

float. weight decay (L2 penalty).

0.0
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
n_sma_threshold int

int. threshold for the number of SMA (simple moving average) steps; 5 is recommended.

5
degenerated_to_sgd bool

bool. perform SGD update when variance of gradient is high.

False
use_gc bool

bool. use Gradient Centralization (both convolution & fc layers).

True
gc_conv_only bool

bool. use Gradient Centralization (only convolution layer).

False
r float

float. EMA factor for AdaNorm; values between 0.9 and 0.99 are preferred.

0.95
adanorm bool

bool. whether to use the AdaNorm variant.

False
adam_debias bool

bool. Only correct the denominator to avoid inflating step sizes early in training.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-05
Source code in pytorch_optimizer/optimizer/ranger.py
class Ranger(Optimizer, BaseOptimizer):
    r"""a synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param n_sma_threshold: int. (recommended is 5).
    :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
    :param use_gc: bool. use Gradient Centralization (both convolution & fc layers).
    :param gc_conv_only: bool. use Gradient Centralization (only convolution layer).
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.95, 0.999),
        alpha: float = 0.5,
        k: int = 6,
        n_sma_threshold: int = 5,
        degenerated_to_sgd: bool = False,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        use_gc: bool = True,
        gc_conv_only: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-5,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_range(alpha, 'alpha', 0.0, 1.0, range_type='[]')
        self.validate_positive(k, 'k')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        self.n_sma_threshold = n_sma_threshold
        self.degenerated_to_sgd = degenerated_to_sgd
        self.use_gc = use_gc
        self.gc_gradient_threshold: int = 3 if gc_conv_only else 1

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'alpha': alpha,
            'k': k,
            'step_counter': 0,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Ranger'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['slow_buffer'] = torch.empty_like(p)
                state['slow_buffer'].copy_(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']
            bias_correction1: float = 1.0 - beta1 ** group['step']

            step_size, n_sma = self.get_rectify_step_size(
                is_rectify=True,
                step=group['step'],
                lr=group['lr'],
                beta2=beta2,
                n_sma_threshold=self.n_sma_threshold,
                degenerated_to_sgd=self.degenerated_to_sgd,
            )

            step_size = self.apply_adam_debias(
                adam_debias=group['adam_debias'],
                step_size=step_size,
                bias_correction1=bias_correction1,
            )

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['slow_buffer'] = torch.empty_like(p)
                    state['slow_buffer'].copy_(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                if self.use_gc and grad.dim() > self.gc_gradient_threshold:
                    centralize_gradient(grad, gc_conv_only=False)

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                if n_sma >= self.n_sma_threshold:
                    de_nom = exp_avg_sq.sqrt().add_(group['eps'])
                    p.addcdiv_(exp_avg, de_nom, value=-step_size)
                else:
                    p.add_(exp_avg, alpha=-step_size)

                if group['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']
                    slow_p.add_(p - slow_p, alpha=group['alpha'])
                    p.copy_(slow_p)

        return loss
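
A construction sketch for Ranger (placeholder model; values shown are the documented defaults). k and alpha control the Lookahead part: every k steps the fast weights are blended into the slow buffer with factor alpha, as in the group['step'] % group['k'] branch above::

    import torch
    from pytorch_optimizer import Ranger

    model = torch.nn.Linear(10, 1)  # placeholder model
    optimizer = Ranger(model.parameters(), lr=1e-3, k=6, alpha=0.5, use_gc=True)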

Ranger21

Bases: Optimizer, BaseOptimizer

Integrating the latest deep learning components into a single optimizer.

Here are the components:
    * uses the AdamW optimizer as its core (or, optionally, MadGrad)
    * Adaptive gradient clipping
    * Gradient centralization
    * Positive-Negative momentum
    * Norm loss
    * Stable weight decay
    * Linear learning rate warm-up
    * Explore-exploit learning rate schedule
    * Lookahead
    * Softplus transformation
    * Gradient Normalization
    * Corrects the denominator (AdamD).

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
num_iterations int

int. number of the total training steps. Ranger21 optimizer schedules the learning rate with its own recipes.

required
lr float

float. learning rate.

0.001
beta0 float

float. Manages the amplitude of the noise introduced by positive-negative momentum. While 0.9 is a recommended default value, you can use -0.5 to minimize the noise.

0.9
betas BETAS

BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.

(0.9, 0.999)
use_softplus bool

bool. use softplus to smooth.

True
beta_softplus float

float. beta for the softplus transformation.

50.0
num_warm_up_iterations Optional[int]

Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate warmup.

None
num_warm_down_iterations Optional[int]

Optional[int]. number of warm-down iterations. Ranger21 performs Explore-exploit learning rate scheduling.

None
agc_clipping_value float

float. clipping threshold for adaptive gradient clipping (AGC).

0.01
agc_eps float

float. eps for adaptive gradient clipping (AGC).

0.001
centralize_gradients bool

bool. use gradient centralization (GC) for both convolution & fc layers.

True
normalize_gradients bool

bool. use gradient normalization.

True
lookahead_merge_time int

int. lookahead merge time (number of steps between lookahead merges).

5
lookahead_blending_alpha float

float. blending alpha.

0.5
weight_decay float

float. weight decay (L2 penalty).

0.0001
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

True
fixed_decay bool

bool. fix weight decay.

False
norm_loss_factor float

float. norm loss factor.

0.0001
adam_debias bool

bool. Only correct the denominator to avoid inflating step sizes early in training.

False
eps float

float. term added to the denominator to improve numerical stability.

1e-08
Source code in pytorch_optimizer/optimizer/ranger21.py
class Ranger21(Optimizer, BaseOptimizer):
    r"""Integrating the latest deep learning components into a single optimizer.

        Here's the components
            * uses the AdamW optimizer as its core (or, optionally, MadGrad)
            * Adaptive gradient clipping
            * Gradient centralization
            * Positive-Negative momentum
            * Norm loss
            * Stable weight decay
            * Linear learning rate warm-up
            * Explore-exploit learning rate schedule
            * Lookahead
            * Softplus transformation
            * Gradient Normalization
            * Corrects the denominator (AdamD).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param num_iterations: int. number of the total training steps. Ranger21 optimizer schedules the learning rate
        with its own recipes.
    :param lr: float. learning rate.
    :param beta0: float. Manages the amplitude of the noise introduced by positive negative momentum
        While 0.9 is a recommended default value, you can use -0.5 to minimize the noise.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param use_softplus: bool. use softplus to smooth.
    :param beta_softplus: float. beta.
    :param num_warm_up_iterations: Optional[int]. number of warm-up iterations. Ranger21 performs linear learning rate
        warmup.
    :param num_warm_down_iterations: Optional[int]. number of warm-down iterations. Ranger21 performs Explore-exploit
        learning rate scheduling.
    :param agc_clipping_value: float.
    :param agc_eps: float. eps for AGC
    :param centralize_gradients: bool. use GC both convolution & fc layers.
    :param normalize_gradients: bool. use gradient normalization.
    :param lookahead_merge_time: int. merge time.
    :param lookahead_blending_alpha: float. blending alpha.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param norm_loss_factor: float. norm loss factor.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(  # pylint: disable=R0913
        self,
        params: PARAMETERS,
        num_iterations: int,
        lr: float = 1e-3,
        beta0: float = 0.9,
        betas: BETAS = (0.9, 0.999),
        use_softplus: bool = True,
        beta_softplus: float = 50.0,
        num_warm_up_iterations: Optional[int] = None,
        num_warm_down_iterations: Optional[int] = None,
        warm_down_min_lr: float = 3e-5,
        agc_clipping_value: float = 1e-2,
        agc_eps: float = 1e-3,
        centralize_gradients: bool = True,
        normalize_gradients: bool = True,
        lookahead_merge_time: int = 5,
        lookahead_blending_alpha: float = 0.5,
        weight_decay: float = 1e-4,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        norm_loss_factor: float = 1e-4,
        adam_debias: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_learning_rate(warm_down_min_lr)
        self.validate_betas(betas)
        self.validate_range(beta0, 'beta0', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(agc_clipping_value, 'agc_clipping_value')
        self.validate_non_negative(eps, 'eps')
        self.validate_non_negative(agc_eps, 'agc_eps')

        self.min_lr = warm_down_min_lr
        self.use_softplus = use_softplus
        self.beta_softplus = beta_softplus
        self.agc_clipping_value = agc_clipping_value
        self.agc_eps = agc_eps
        self.centralize_gradients = centralize_gradients
        self.normalize_gradients = normalize_gradients
        self.lookahead_merge_time = lookahead_merge_time
        self.lookahead_blending_alpha = lookahead_blending_alpha
        self.norm_loss_factor = norm_loss_factor

        self.lookahead_step: int = 0
        self.starting_lr: float = lr
        self.current_lr: float = lr

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        super().__init__(params, defaults)

        self.num_warm_up_iterations: int = (
            self.build_warm_up_iterations(num_iterations, betas[1])
            if num_warm_up_iterations is None
            else num_warm_up_iterations
        )
        self.num_warm_down_iterations: int = (
            self.build_warm_down_iterations(num_iterations)
            if num_warm_down_iterations is None
            else num_warm_down_iterations
        )
        self.start_warm_down: int = num_iterations - self.num_warm_down_iterations
        self.warm_down_lr_delta: float = self.starting_lr - self.min_lr

    def __str__(self) -> str:
        return 'Ranger21'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['grad_ma'] = torch.zeros_like(p)
                state['variance_ma'] = torch.zeros_like(p)
                state['lookahead_params'] = p.clone()
                state['neg_grad_ma'] = torch.zeros_like(p)
                state['max_variance_ma'] = torch.zeros_like(p)

    @staticmethod
    def build_warm_up_iterations(total_iterations: int, beta2: float, warm_up_pct: float = 0.22) -> int:
        warm_up_iterations: int = math.ceil(2.0 / (1.0 - beta2))  # default un-tuned linear warmup
        beta_pct: float = warm_up_iterations / total_iterations
        return int(warm_up_pct * total_iterations) if beta_pct > 0.45 else warm_up_iterations

    @staticmethod
    def build_warm_down_iterations(total_iterations: int, warm_down_pct: float = 0.72) -> int:
        start_warm_down: int = int(warm_down_pct * total_iterations)
        return total_iterations - start_warm_down

    def warm_up_dampening(self, lr: float, step: int) -> float:
        if step > self.num_warm_up_iterations:
            return lr

        warm_up_current_pct: float = min(1.0, (step / self.num_warm_up_iterations))

        self.current_lr = lr * warm_up_current_pct

        return self.current_lr

    def warm_down(self, lr: float, iteration: int) -> float:
        if iteration < self.start_warm_down:
            return lr

        # start iteration from 1, not 0
        warm_down_iteration: int = max((iteration + 1) - self.start_warm_down, 1)
        warm_down_pct: float = min(warm_down_iteration / (self.num_warm_down_iterations + 1), 1.0)

        self.current_lr = max(self.starting_lr - self.warm_down_lr_delta * warm_down_pct, self.min_lr)

        return self.current_lr

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        param_size: int = 0
        variance_ma_sum: float = 1.0

        # Phase 1 - Accumulate all the variance_ma_sum to use in stable weight decay
        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction2: float = 1.0 - beta2 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                param_size += p.numel()

                state = self.state[p]
                if len(state) == 0:
                    state['grad_ma'] = torch.zeros_like(p)
                    state['variance_ma'] = torch.zeros_like(p)
                    state['lookahead_params'] = p.clone()
                    state['neg_grad_ma'] = torch.zeros_like(p)
                    state['max_variance_ma'] = torch.zeros_like(p)

                # Apply Adaptive Gradient Clipping (AGC)
                grad.copy_(agc(p, grad, self.agc_eps, self.agc_clipping_value))

                # Apply gradient centralization & normalization
                centralize_gradient(grad, gc_conv_only=False)
                normalize_gradient(grad)

                # second moment estimation
                # using positive-negative momentum and bias correction
                variance_ma = state['variance_ma']
                variance_ma.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                variance_ma_sum += (variance_ma / bias_correction2).sum()

        if param_size == 0:
            raise ZeroParameterSizeError()

        variance_normalized = math.sqrt(variance_ma_sum / param_size)

        # Phase 2 - Apply weight decay and step
        for group in self.param_groups:
            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']  # fmt: skip
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])  # fmt: skip

            noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)  # fmt: skip

            # warm up & down
            lr: float = self.warm_up_dampening(group['lr'], group['step'])
            lr = self.warm_down(lr, group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                # stable weight decay
                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=lr,
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=1.0 / variance_normalized,
                )

                # norm loss
                correction = 2.0 * self.norm_loss_factor * (1.0 - 1.0 / unit_norm(p).add_(group['eps']))
                p.mul_(1.0 - lr * correction)

                state = self.state[p]
                if group['step'] % 2 == 1:
                    grad_ma, neg_grad_ma = state['grad_ma'], state['neg_grad_ma']
                else:
                    grad_ma, neg_grad_ma = state['neg_grad_ma'], state['grad_ma']

                variance_ma = state['variance_ma']
                torch.max(state['max_variance_ma'], variance_ma, out=variance_ma)

                de_nom = (variance_ma.sqrt() / bias_correction2_sq).add_(group['eps'])

                if self.use_softplus:
                    de_nom = f.softplus(de_nom, beta=self.beta_softplus)

                grad = p.grad
                centralize_gradient(grad, gc_conv_only=False)
                normalize_gradient(grad)

                grad_ma.mul_(beta1 ** 2).add_(grad, alpha=1.0 - beta1 ** 2)  # fmt: skip

                step_size: float = self.apply_adam_debias(group['adam_debias'], lr, bias_correction1)

                pn_momentum = grad_ma.mul(1.0 + 1.0).add(neg_grad_ma, alpha=-1.0).mul(1.0 / noise_norm)
                p.addcdiv_(pn_momentum, de_nom, value=-step_size)

        self.lookahead_process_step()

        return loss

    def lookahead_process_step(self):
        self.lookahead_step += 1
        if self.lookahead_step >= self.lookahead_merge_time:
            self.lookahead_step: int = 0
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    state = self.state[p]

                    p.mul_(self.lookahead_blending_alpha).add_(
                        state['lookahead_params'],
                        alpha=1.0 - self.lookahead_blending_alpha,
                    )
                    state['lookahead_params'].copy_(p)
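
A construction sketch for Ranger21 (placeholder model and schedule figures). num_iterations is a required argument because the optimizer derives its warm-up and warm-down learning-rate schedule from it::

    import torch
    from pytorch_optimizer import Ranger21

    model = torch.nn.Linear(10, 1)          # placeholder model
    num_epochs, steps_per_epoch = 10, 100   # placeholder schedule figures

    optimizer = Ranger21(model.parameters(), num_iterations=num_epochs * steps_per_epoch, lr=1e-3)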

RotoGrad

Bases: RotateOnly

Implementation of RotoGrad as described in the original paper.

Parameters:

Name Type Description Default
backbone Module

nn.Module. shared module.

required
heads Sequence[Module]

List[nn.Module]. task-specific modules.

required
latent_size int

int. size of the shared representation, i.e. the size of the backbone's output.

required
burn_in_period int

int. When back-propagating towards the shared parameters, each task loss is normalized by dividing by its initial value, L_k(t) / L_k(t_0 = 0). This parameter sets the number of iterations after which the denominator is replaced by the value of the loss at that iteration, that is, t_0 = burn_in_period. This is done to overcome problems with losses changing quickly in the first iterations.

20
normalize_losses bool

bool. Whether to use these normalized losses to back-propagate through the task-specific parameters as well.

False
Source code in pytorch_optimizer/optimizer/rotograd.py
class RotoGrad(RotateOnly):
    r"""Implementation of RotoGrad as described in the original paper.

    :param backbone: nn.Module. shared module.
    :param heads: List[nn.Module]. task-specific modules.
    :param latent_size: int. size of the shared representation, size of the output of the backbone.z.
    :param burn_in_period: int. When back-propagating towards the shared parameters, *each task loss is normalized
        dividing by its initial value*, :math:`{L_k(t)}/{L_k(t_0 = 0)}`. This parameter sets a number of iterations
        after which the denominator will be replaced by the value of the loss at that iteration, that is,
        :math:`t_0 = burn\_in\_period`. This is done to overcome problems with losses quickly changing
        in the first iterations.
    :param normalize_losses: bool. Whether to use this normalized losses to back-propagate through the task-specific
        parameters as well.
    """

    num_tasks: int
    backbone: nn.Module
    heads: Sequence[nn.Module]
    rep: torch.Tensor

    def __init__(
        self,
        backbone: nn.Module,
        heads: Sequence[nn.Module],
        latent_size: int,
        *args,
        burn_in_period: int = 20,
        normalize_losses: bool = False,
    ):
        super().__init__(backbone, heads, latent_size, burn_in_period, *args, normalize_losses=normalize_losses)

        self.initial_grads = None
        self.counter: int = 0

    def _rep_grad(self):
        super()._rep_grad()

        grad_norms = [torch.linalg.norm(g, keepdim=True).clamp_min(1e-15) for g in self.original_grads]
        if self.initial_grads is None or self.counter == self.burn_in_period:
            self.initial_grads = grad_norms
            conv_ratios = [torch.ones((1,)) for _ in range(len(self.initial_grads))]
        else:
            conv_ratios = [x / y for x, y, in zip(grad_norms, self.initial_grads)]

        self.counter += 1

        alphas = [x / torch.clamp(sum(conv_ratios), 1e-15) for x in conv_ratios]
        weighted_sum_norms = sum(a * g for a, g in zip(alphas, grad_norms))

        return sum(g / n * weighted_sum_norms for g, n in zip(self.original_grads, grad_norms))

SAM

Bases: Optimizer, BaseOptimizer

Sharpness-Aware Minimization for Efficiently Improving Generalization.

Example:

Here's an example::

    model = YourModel()
    base_optimizer = Ranger21
    optimizer = SAM(model.parameters(), base_optimizer)

    for input, output in data:
        # first forward-backward pass

        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward pass
        # make sure to do a full forward pass
        loss_function(output, model(input)).backward()
        optimizer.second_step(zero_grad=True)

Alternative example with a single closure-based step function::

    model = YourModel()
    base_optimizer = Ranger21
    optimizer = SAM(model.parameters(), base_optimizer)

    def closure():
        loss = loss_function(output, model(input))
        loss.backward()
        return loss

    for input, output in data:
        loss = loss_function(output, model(input))
        loss.backward()
        optimizer.step(closure)
        optimizer.zero_grad()

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer OPTIMIZER

Optimizer. base optimizer.

required
rho float

float. size of the neighborhood for computing the max loss.

0.05
adaptive bool

bool. element-wise Adaptive SAM.

False
kwargs

Dict. parameters forwarded to the base optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
class SAM(Optimizer, BaseOptimizer):
    r"""Sharpness-Aware Minimization for Efficiently Improving Generalization.

    Example:
    -------
        Here's an example::

            model = YourModel()
            base_optimizer = Ranger21
            optimizer = SAM(model.parameters(), base_optimizer)

            for input, output in data:
                # first forward-backward pass

                loss = loss_function(output, model(input))
                loss.backward()
                optimizer.first_step(zero_grad=True)

                # second forward-backward pass
                # make sure to do a full forward pass
                loss_function(output, model(input)).backward()
                optimizer.second_step(zero_grad=True)

        Alternative example with a single closure-based step function::

            model = YourModel()
            base_optimizer = Ranger21
            optimizer = SAM(model.parameters(), base_optimizer)

            def closure():
                loss = loss_function(output, model(input))
                loss.backward()
                return loss

            for input, output in data:
                loss = loss_function(output, model(input))
                loss.backward()
                optimizer.step(closure)
                optimizer.zero_grad()

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: Optimizer. base optimizer.
    :param rho: float. size of the neighborhood for computing the max loss.
    :param adaptive: bool. element-wise Adaptive SAM.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        params: PARAMETERS,
        base_optimizer: OPTIMIZER,
        rho: float = 0.05,
        adaptive: bool = False,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')

        defaults: DEFAULTS = {'rho': rho, 'adaptive': adaptive}
        defaults.update(kwargs)
        super().__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'SAM'

    @torch.no_grad()
    def reset(self):
        pass

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        grad_norm = self.grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)

            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['old_p'] = p.clone()
                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p)

                # climb to the local maximum "w + e(w)"
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # get back to "w" from "w + e(w)"
                p.data = self.state[p]['old_p']

        # do the actual "sharpness-aware" update
        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        if closure is None:
            raise NoClosureError(str(self))

        self.first_step(zero_grad=True)

        # the closure should do a full forward-backward pass
        with torch.enable_grad():
            closure()

        self.second_step()

    def grad_norm(self) -> torch.Tensor:
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]['params'][0].device
        return torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
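
The keyword arguments passed to SAM are forwarded to the base optimizer, and adaptive=True switches on the element-wise (ASAM) variant. A small construction sketch, using torch.optim.AdamW purely as an example base optimizer; the two forward-backward passes are then driven exactly as in the docstring examples above::

    import torch
    from pytorch_optimizer import SAM

    model = torch.nn.Linear(10, 1)  # placeholder model
    # rho and adaptive configure SAM itself; lr and weight_decay are forwarded to AdamW.
    optimizer = SAM(model.parameters(), base_optimizer=torch.optim.AdamW,
                    rho=0.05, adaptive=True, lr=1e-3, weight_decay=1e-2)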

AccSGD

Bases: Optimizer, BaseOptimizer

Accelerating Stochastic Gradient Descent For Least Squares Regression.

Parameters:

Name Type Description Default
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
lr float

float. learning rate.

0.001
kappa float

float. ratio of long to short step.

1000.0
xi float

float. statistical advantage parameter.

10.0
constant float

float. any small constant under 1.

0.7
weight_decay float

float. weight decay.

0.0
Source code in pytorch_optimizer/optimizer/sgd.py
class AccSGD(Optimizer, BaseOptimizer):
    r"""Accelerating Stochastic Gradient Descent For Least Squares Regression.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param kappa: float. ratio of long to short step.
    :param xi: float. statistical advantage parameter.
    :param constant: float. any small constant under 1.
    :param weight_decay: float. weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        kappa: float = 1000.0,
        xi: float = 10.0,
        constant: float = 0.7,
        weight_decay: float = 0.0,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(kappa, 'kappa')
        self.validate_non_negative(xi, 'xi')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_boundary(constant, boundary=1.0, bound_type='upper')

        defaults: DEFAULTS = {
            'lr': lr,
            'kappa': kappa,
            'xi': xi,
            'constant': constant,
            'weight_decay': weight_decay,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'AccSGD'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['momentum_buffer'] = p.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            large_lr: float = group['lr'] * group['kappa'] / group['constant']
            alpha: float = 1.0 - (group['xi'] * (group['constant'] ** 2) / group['kappa'])
            beta: float = 1.0 - alpha
            zeta: float = group['constant'] / (group['constant'] + beta)

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['momentum_buffer'] = p.clone()

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=False,
                    fixed_decay=False,
                )

                buf = state['momentum_buffer']
                buf.mul_((1.0 / beta) - 1.0).add_(grad, alpha=-large_lr).add_(p).mul_(beta)

                p.add_(grad, alpha=-group['lr']).mul_(zeta).add_(buf, alpha=1.0 - zeta)

        return loss
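
For intuition, a quick sanity check of the derived step quantities with the default hyperparameters; the numbers simply restate the formulas in step above and are not part of the library.

lr, kappa, xi, constant = 1e-3, 1000.0, 10.0, 0.7

large_lr = lr * kappa / constant              # ~1.4286, the long step
alpha = 1.0 - xi * constant ** 2 / kappa      # ~0.9951
beta = 1.0 - alpha                            # ~0.0049
zeta = constant / (constant + beta)           # ~0.9930, mixing weight between parameter and buffer

print(large_lr, alpha, beta, zeta)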

SGDW

Bases: Optimizer, BaseOptimizer

Decoupled Weight Decay Regularization.

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.0001.
- momentum (float): momentum factor. Default: 0.0.
- weight_decay (float): weight decay (L2 penalty). Default: 0.0.
- weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW. Default: True.
- dampening (float): dampening for momentum. Default: 0.0.
- nesterov (bool): enables Nesterov momentum. Default: False.
Source code in pytorch_optimizer/optimizer/sgd.py
class SGDW(Optimizer, BaseOptimizer):
    r"""Decoupled Weight Decay Regularization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param dampening: float. dampening for momentum.
    :param nesterov: bool. enables Nesterov momentum
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-4,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        dampening: float = 0.0,
        nesterov: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'dampening': dampening,
            'nesterov': nesterov,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDW'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                if group['momentum'] > 0.0:
                    state['momentum_buffer'] = p.grad.clone()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0 and momentum > 0.0:
                    state['momentum_buffer'] = grad.clone()

                if momentum > 0.0:
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

                    if group['nesterov']:
                        grad.add_(buf, alpha=momentum)
                    else:
                        grad = buf

                self.apply_weight_decay(
                    p,
                    grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=False,
                )

                p.add_(grad, alpha=-group['lr'])

        return loss
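
A short usage sketch (placeholder model and data) with the default decoupled decay; setting weight_decouple=False instead folds the decay into the gradient as a classic L2 penalty.

import torch
from pytorch_optimizer import SGDW

model = torch.nn.Linear(4, 1)

# decoupled decay (the default): weights are shrunk directly, independently of the gradient
optimizer = SGDW(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=1e-2, weight_decouple=True)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()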

ASGD

Bases: Optimizer, BaseOptimizer

Adaptive SGD with estimation of the local smoothness (curvature).

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.01.
- amplifier (float): amplifier. Default: 0.02.
- weight_decay (float): weight decay (L2 penalty). Default: 0.0.
- weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW. Default: True.
- fixed_decay (bool): fix weight decay. Default: False.
- theta (float): theta. Default: 1.0.
- dampening (float): dampening for momentum. Default: 1.0.
- eps (float): term added to the denominator to improve numerical stability. Default: 1e-05.
Source code in pytorch_optimizer/optimizer/sgd.py
class ASGD(Optimizer, BaseOptimizer):
    r"""Adaptive SGD with estimation of the local smoothness (curvature).

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param amplifier: float. amplifier.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param theta: float. theta.
    :param dampening: float. dampening for momentum.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        amplifier: float = 0.02,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        theta: float = 1.0,
        dampening: float = 1.0,
        eps: float = 1e-5,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(amplifier, 'amplifier')
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'amplifier': amplifier,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'theta': theta,
            'dampening': dampening,
            'eps': eps,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ASGD'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for _ in group['params']:
                pass

    @staticmethod
    def get_norms_by_group(group: Dict, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Get parameter & gradient norm by group."""
        p_norm = torch.zeros(1, dtype=torch.float32, device=device)
        g_norm = torch.zeros(1, dtype=torch.float32, device=device)

        for p in group['params']:
            if p.grad is None:
                continue

            p_norm.add_(p.norm().pow(2))
            g_norm.add_(p.grad.norm().pow(2))

        p_norm.sqrt_()
        g_norm.sqrt_()

        return p_norm, g_norm

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'prev_param_norm' not in group and 'prev_grad_norm' not in group:
                group['prev_param_norm'], group['prev_grad_norm'] = self.get_norms_by_group(
                    group,
                    device=group['params'][0].device,
                )

            group['curr_param_norm'], group['curr_grad_norm'] = self.get_norms_by_group(
                group,
                device=group['params'][0].device,
            )

            param_diff_norm: float = (group['curr_param_norm'] - group['prev_param_norm']).item()
            grad_diff_norm: float = (group['curr_grad_norm'] - group['prev_grad_norm']).item()

            new_lr: float = group['lr'] * math.sqrt(1 + group['amplifier'] * group['theta'])
            if param_diff_norm > 0 and grad_diff_norm > 0:
                new_lr = min(new_lr, param_diff_norm / (group['dampening'] * grad_diff_norm)) + group['eps']

            group['theta'] = new_lr / group['lr']
            group['lr'] = new_lr

            group['prev_param_norm'].copy_(group['curr_param_norm'])
            group['prev_grad_norm'].copy_(group['curr_grad_norm'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                p.add_(grad, alpha=-new_lr)

        return loss

get_norms_by_group(group, device) staticmethod

Get parameter & gradient norm by group.

Source code in pytorch_optimizer/optimizer/sgd.py
@staticmethod
def get_norms_by_group(group: Dict, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Get parameter & gradient norm by group."""
    p_norm = torch.zeros(1, dtype=torch.float32, device=device)
    g_norm = torch.zeros(1, dtype=torch.float32, device=device)

    for p in group['params']:
        if p.grad is None:
            continue

        p_norm.add_(p.norm().pow(2))
        g_norm.add_(p.grad.norm().pow(2))

    p_norm.sqrt_()
    g_norm.sqrt_()

    return p_norm, g_norm
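
A small sketch (placeholder model and data) showing that ASGD rewrites the group learning rate in place from the norm differences computed in step:

import torch
from pytorch_optimizer import ASGD

model = torch.nn.Linear(4, 1)
optimizer = ASGD(model.parameters(), lr=1e-2)

for _ in range(3):
    optimizer.zero_grad()
    model(torch.randn(16, 4)).pow(2).mean().backward()
    optimizer.step()
    print(optimizer.param_groups[0]['lr'])  # adapted in place from the norm differences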

SignSGD

Bases: Optimizer, BaseOptimizer

Compressed Optimisation for Non-Convex Problems.

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.001.
- momentum (float): momentum factor (0.0 = SignSGD, >0 = Signum). Default: 0.9.
- weight_decay (float): weight decay (L2 penalty). Default: 0.0.
- weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW. Default: True.
Source code in pytorch_optimizer/optimizer/sgd.py
class SignSGD(Optimizer, BaseOptimizer):
    r"""Compressed Optimisation for Non-Convex Problems.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor (0.0 = SignSGD, >0 = Signum).
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                if group['momentum'] > 0.0:
                    state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if momentum > 0.0:
                    if len(state) == 0:
                        state['momentum_buffer'] = torch.zeros_like(p)

                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(grad, alpha=1.0 - momentum)
                else:
                    buf = grad

                p.add_(torch.sign(buf), alpha=-group['lr'])

        return loss
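
Because only the sign of the (possibly momentum-averaged) gradient is used, every coordinate moves by exactly lr regardless of gradient magnitude. A tiny self-contained check:

import torch
from pytorch_optimizer import SignSGD

p = torch.nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
optimizer = SignSGD([p], lr=0.1, momentum=0.0)  # momentum=0.0 -> pure SignSGD

p.grad = torch.tensor([0.5, -40.0, 1e-4])  # wildly different gradient magnitudes
optimizer.step()

print(p.data)  # tensor([0.9000, -1.9000, 2.9000]); each coordinate moved by exactly lr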

SGDP

Bases: Optimizer, BaseOptimizer

SGD + Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.001.
- momentum (float): momentum factor. Default: 0.0.
- dampening (float): dampening for momentum. Default: 0.0.
- weight_decay (float): weight decay (L2 penalty). Default: 0.0.
- weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW. Default: True.
- fixed_decay (bool): fix weight decay. Default: False.
- delta (float): threshold that determines whether a set of parameters is scale invariant or not. Default: 0.1.
- wd_ratio (float): relative weight decay applied on scale-invariant parameters compared to that applied on scale-variant parameters. Default: 0.1.
- nesterov (bool): enables Nesterov momentum. Default: False.
- eps (float): term added to the denominator to improve numerical stability. Default: 1e-08.
Source code in pytorch_optimizer/optimizer/sgdp.py
class SGDP(Optimizer, BaseOptimizer):
    r"""SGD + Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param dampening: float. dampening for momentum.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param delta: float. threshold that determines whether a set of parameters is scale invariant or not.
    :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
        on scale-variant parameters.
    :param nesterov: bool. enables nesterov momentum.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        delta: float = 0.1,
        wd_ratio: float = 0.1,
        nesterov: bool = False,
        eps: float = 1e-8,
    ):
        self.validate_learning_rate(lr)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_range(wd_ratio, 'wd_ratio', 0.0, 1.0)
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'momentum': momentum,
            'dampening': dampening,
            'delta': delta,
            'wd_ratio': wd_ratio,
            'nesterov': nesterov,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SGDP'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['momentum'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(p)

                buf = state['momentum']
                buf.mul_(momentum).add_(grad, alpha=1.0 - group['dampening'])

                d_p = buf.clone()
                if group['nesterov']:
                    d_p = d_p.mul_(momentum).add_(grad)

                wd_ratio: float = 1.0
                if len(p.shape) > 1:
                    d_p, wd_ratio = projection(
                        p,
                        grad,
                        d_p,
                        group['delta'],
                        group['wd_ratio'],
                        group['eps'],
                    )

                self.apply_weight_decay(
                    p=p,
                    grad=None,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                    ratio=wd_ratio / (1.0 - momentum),
                )

                p.add_(d_p, alpha=-group['lr'])

        return loss
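
A brief usage sketch (placeholder model and data); the projection path above only engages for parameters with more than one dimension, e.g. convolution weights followed by normalization layers:

import torch
from pytorch_optimizer import SGDP

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.BatchNorm2d(8))
optimizer = SGDP(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4, wd_ratio=0.1)

model(torch.randn(2, 3, 16, 16)).mean().backward()
optimizer.step()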

Shampoo

Bases: Optimizer, BaseOptimizer

Preconditioned Stochastic Tensor Optimization.

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.001.
- momentum (float): momentum. Default: 0.0.
- weight_decay (float): weight decay (L2 penalty). Default: 0.0.
- weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW. Default: False.
- fixed_decay (bool): fix weight decay. Default: False.
- preconditioning_compute_steps (int): performance tuning parameter for controlling memory and compute requirements; how often to compute the pre-conditioner. Default: 1.
- matrix_eps (float): term added to the denominator to improve numerical stability. Default: 1e-06.
Source code in pytorch_optimizer/optimizer/shampoo.py
class Shampoo(Optimizer, BaseOptimizer):
    r"""Preconditioned Stochastic Tensor Optimization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
        requirements. How often to compute pre-conditioner.
    :param matrix_eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        preconditioning_compute_steps: int = 1,
        matrix_eps: float = 1e-6,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_step(preconditioning_compute_steps, 'preconditioning_compute_steps')
        self.validate_non_negative(matrix_eps, 'matrix_eps')

        self.preconditioning_compute_steps = preconditioning_compute_steps

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'matrix_eps': matrix_eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Shampoo'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            momentum = group['momentum']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    if momentum > 0.0:
                        state['momentum_buffer'] = grad.clone()

                    for dim_id, dim in enumerate(grad.size()):
                        state[f'pre_cond_{dim_id}'] = group['matrix_eps'] * torch.eye(dim, out=grad.new(dim, dim))
                        state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()

                if momentum > 0.0:
                    grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                order: int = grad.ndimension()
                original_size: torch.Size = grad.size()
                for dim_id, dim in enumerate(grad.size()):
                    pre_cond, inv_pre_cond = state[f'pre_cond_{dim_id}'], state[f'inv_pre_cond_{dim_id}']

                    grad = grad.transpose_(0, dim_id).contiguous()
                    transposed_size = grad.size()

                    grad = grad.view(dim, -1)
                    grad_t = grad.t()

                    pre_cond.add_(grad @ grad_t)
                    if group['step'] % self.preconditioning_compute_steps == 0:
                        inv_pre_cond.copy_(compute_power_svd(pre_cond, order))

                    if dim_id == order - 1:
                        grad = grad_t @ inv_pre_cond
                        grad = grad.view(original_size)
                    else:
                        grad = inv_pre_cond @ grad
                        grad = grad.view(transposed_size)

                state['momentum_buffer'] = grad

                p.add_(grad, alpha=-group['lr'])

        return loss
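
A small sketch (synthetic parameter and gradient) illustrating the state layout created above: for an (m, n) parameter, Shampoo keeps one (m, m) and one (n, n) pre-conditioner.

import torch
from pytorch_optimizer import Shampoo

w = torch.nn.Parameter(torch.randn(64, 32))
optimizer = Shampoo([w], lr=1e-3)

w.grad = torch.randn_like(w)
optimizer.step()

state = optimizer.state[w]
print(state['pre_cond_0'].shape, state['pre_cond_1'].shape)  # torch.Size([64, 64]) torch.Size([32, 32])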

ScalableShampoo

Bases: Optimizer, BaseOptimizer

Scalable Preconditioned Stochastic Tensor Optimization.

This version of the Scalable Shampoo optimizer targets a single-GPU environment rather than a distributed setup or XLA devices. The original design computes pre-conditioners asynchronously on distributed CPUs; this implementation computes them synchronously on the GPU, which accounts for roughly 99% of the optimization time.

Still, it is much faster than the previous Shampoo optimizer, because it uses a coupled Newton iteration to compute the G^{-1/p} matrices, whereas the previous one relies on SVD, which is really slow.

This implementation also offers
    1. many plug-ins (e.g. gradient grafting, choice of pre-conditioning, etc.)
    2. features not yet implemented in the official PyTorch code.
    3. readable, organized, clean code.

Reference: https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.001.
- betas (BETAS): beta1, beta2. Default: (0.9, 0.999).
- moving_average_for_momentum (bool): perform moving_average for momentum (beta1). Default: False.
- weight_decay (float): weight decay (L2 penalty). Default: 0.0.
- decoupled_weight_decay (bool): use decoupled weight_decay. Default: False.
- decoupled_learning_rate (bool): use decoupled lr, otherwise couple it with the preconditioned gradient. Default: True.
- inverse_exponent_override (int): fixed exponent for the pre-conditioner, if > 0. Default: 0.
- start_preconditioning_step (int): the step from which pre-conditioned updates are applied (see is_precondition_step). Default: 25.
- preconditioning_compute_steps (int): performance tuning parameter for controlling memory and compute requirements; how often to compute the pre-conditioner. Ideally, 1 is best, but the current implementation does not work in a distributed environment (there is no statistics & pre-conditioner sync among replicas), computes on the GPU (not the CPU), and uses fp32 (not fp64) precision. Also, according to the paper, preconditioning_compute_steps does not have a significant effect on performance, so if speed is a problem, try setting this step larger (e.g. 1000). Default: 1000.
- statistics_compute_steps (int): how often to compute statistics. Usually set to 1 (or 10). Default: 1.
- block_size (int): block size for large layers (if > 0). Block size = 1 means AdaGrad (don't do this, it is extremely inefficient!). The block size should be as large as feasible under memory/time constraints. Default: 512.
- skip_preconditioning_rank_lt (int): skips preconditioning for parameters with rank less than this value. Default: 1.
- no_preconditioning_for_layers_with_dim_gt (int): avoid preconditioning large layers to reduce overall memory. Default: 8192.
- shape_interpretation (bool): automatic shape interpretation (e.g. [4, 3, 1024, 512] would result in 12 x [1024, 512] L and R statistics); when disabled, Shampoo constructs statistics of shape [4, 4], [3, 3], [1024, 1024], [512, 512]. Default: True.
- graft_type (int): type of grafting (SGD, AdaGrad, RMSProp, SQRT_N, or None). Default: SGD.
- pre_conditioner_type (int): type of pre-conditioner. Default: ALL.
- nesterov (bool): Nesterov momentum. Default: True.
- diagonal_eps (float): term added to the denominator to improve numerical stability. Default: 1e-10.
- matrix_eps (float): term added to the denominator to improve numerical stability. Default: 1e-06.
- use_svd (bool): use SVD instead of the Schur-Newton method to calculate M^{-1/p}. Theoretically the Schur-Newton method is faster, but because the loop-based code is inefficient while the SVD kernel is well optimized, SVD is much faster in some cases (usually for small models). See https://github.com/kozistr/pytorch_optimizer/pull/103. Default: False.
Source code in pytorch_optimizer/optimizer/shampoo.py
class ScalableShampoo(Optimizer, BaseOptimizer):
    r"""Scalable Preconditioned Stochastic Tensor Optimization.

        This version of Scalable Shampoo Optimizer aims for a single GPU environment, not for a distributed environment
        or XLA devices. So, the original intention is to compute pre-conditioners asynchronously on the distributed
        CPUs, but this implementation calculates them which takes 99% of the optimization time on a GPU synchronously.

        Still, it is much faster than the previous Shampoo Optimizer because using coupled Newton iteration when
        computing G^{-1/p} matrices while the previous one uses SVD which is really slow.

        Also, this implementation offers
            1. lots of plug-ins (e.g. gradient grafting, type of pre-conditioning, etc)
            2. not-yet implemented features in the official Pytorch code.
            3. readable, organized, clean code.

        Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. beta1, beta2.
    :param moving_average_for_momentum: bool. perform moving_average for momentum (beta1).
    :param weight_decay: float. weight decay (L2 penalty).
    :param decoupled_weight_decay: bool. use decoupled weight_decay.
    :param decoupled_learning_rate: bool. use decoupled lr, otherwise couple it w/ preconditioned gradient.
    :param inverse_exponent_override: int. fixed exponent for pre-conditioner, if > 0.
    :param start_preconditioning_step: int.
    :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
        requirements. How often to compute pre-conditioner. Ideally, 1 is the best. However, the current implementation
        doesn't work on the distributed environment (there are no statistics & pre-conditioners sync among replicas),
        compute on the GPU (not CPU) and the precision is fp32 (not fp64).
        Also, followed by the paper, `preconditioning_compute_steps` does not have a significant effect on the
        performance. So, If you have a problem with the speed, try to set this step bigger (e.g. 1000).
    :param statistics_compute_steps: int. How often to compute statistics. usually set to 1 (or 10).
    :param block_size: int. Block size for large layers (if > 0).
        Block size = 1 ==> AdaGrad (Don't do this, extremely inefficient!)
        Block size should be as large as feasible under memory/time constraints.
    :param skip_preconditioning_rank_lt: int. Skips preconditioning for parameters with rank less than this value.
    :param no_preconditioning_for_layers_with_dim_gt: int. avoid preconditioning large layers to reduce overall memory.
    :param shape_interpretation: bool. Automatic shape interpretation (for eg: [4, 3, 1024, 512] would
        result in 12 x [1024, 512] L and R statistics. Disabled by default which results in Shampoo constructing
        statistics [4, 4], [3, 3], [1024, 1024], [512, 512].
    :param graft_type: int. type of grafting (SGD or AdaGrad or RMSProp or SQRT_N or None).
    :param pre_conditioner_type: int. type of pre-conditioner.
    :param nesterov: bool. Nesterov momentum.
    :param diagonal_eps: float. term added to the denominator to improve numerical stability.
    :param matrix_eps: float. term added to the denominator to improve numerical stability.
    :param use_svd: bool. use SVD instead of Schur-Newton method to calculate M^{-1/p}.
        Theoretically, Schur-Newton method is faster than SVD method. However, the inefficiency of the loop code and
        proper svd kernel, SVD is much faster in some cases (usually in case of small models).
        see https://github.com/kozistr/pytorch_optimizer/pull/103
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        moving_average_for_momentum: bool = False,
        weight_decay: float = 0.0,
        decoupled_weight_decay: bool = False,
        decoupled_learning_rate: bool = True,
        inverse_exponent_override: int = 0,
        start_preconditioning_step: int = 25,
        preconditioning_compute_steps: int = 1000,
        statistics_compute_steps: int = 1,
        block_size: int = 512,
        skip_preconditioning_rank_lt: int = 1,
        no_preconditioning_for_layers_with_dim_gt: int = 8192,
        shape_interpretation: bool = True,
        graft_type: int = LayerWiseGrafting.SGD,
        pre_conditioner_type: int = PreConditionerType.ALL,
        nesterov: bool = True,
        diagonal_eps: float = 1e-10,
        matrix_eps: float = 1e-6,
        use_svd: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_step(start_preconditioning_step, 'start_preconditioning_step')
        self.validate_step(preconditioning_compute_steps, 'preconditioning_compute_steps')
        self.validate_step(statistics_compute_steps, 'statistics_compute_steps')
        self.validate_non_negative(diagonal_eps, 'diagonal_eps')
        self.validate_non_negative(matrix_eps, 'matrix_eps')

        self.inverse_exponent_override = inverse_exponent_override
        self.start_preconditioning_step = start_preconditioning_step
        self.preconditioning_compute_steps = preconditioning_compute_steps
        self.statistics_compute_steps = statistics_compute_steps
        self.block_size = block_size
        self.skip_preconditioning_rank_lt = skip_preconditioning_rank_lt
        self.no_preconditioning_for_layers_with_dim_gt = no_preconditioning_for_layers_with_dim_gt
        self.shape_interpretation = shape_interpretation
        self.graft_type = graft_type
        self.pre_conditioner_type = pre_conditioner_type
        self.diagonal_eps = diagonal_eps
        self.matrix_eps = matrix_eps
        self.use_svd = use_svd

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'decoupled_weight_decay': decoupled_weight_decay,
            'decoupled_learning_rate': decoupled_learning_rate,
            'moving_average_for_momentum': moving_average_for_momentum,
            'nesterov': nesterov,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'ScalableShampoo'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['momentum'] = torch.zeros_like(p)
                state['pre_conditioner'] = PreConditioner(
                    p,
                    group['betas'][1],  # beta2
                    self.inverse_exponent_override,
                    self.block_size,
                    self.skip_preconditioning_rank_lt,
                    self.no_preconditioning_for_layers_with_dim_gt,
                    self.shape_interpretation,
                    self.pre_conditioner_type,
                    self.matrix_eps,
                    self.use_svd,
                )
                state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)

    def is_precondition_step(self, step: int) -> bool:
        return step >= self.start_preconditioning_step

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            is_precondition_step: bool = self.is_precondition_step(group['step'])
            pre_conditioner_multiplier: float = 1.0 if group['decoupled_learning_rate'] else group['lr']

            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(p)
                    state['pre_conditioner'] = PreConditioner(
                        p,
                        beta2,
                        self.inverse_exponent_override,
                        self.block_size,
                        self.skip_preconditioning_rank_lt,
                        self.no_preconditioning_for_layers_with_dim_gt,
                        self.shape_interpretation,
                        self.pre_conditioner_type,
                        self.matrix_eps,
                        self.use_svd,
                    )
                    state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)

                pre_conditioner, graft = state['pre_conditioner'], state['graft']

                graft.add_statistics(grad, beta2)
                if group['step'] % self.statistics_compute_steps == 0:
                    pre_conditioner.add_statistics(grad)
                if group['step'] % self.preconditioning_compute_steps == 0:
                    pre_conditioner.compute_pre_conditioners()

                graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                shampoo_grad: torch.Tensor = (
                    pre_conditioner.preconditioned_grad(grad) if is_precondition_step else grad
                )

                if self.graft_type != LayerWiseGrafting.NONE:
                    graft_norm = torch.linalg.norm(graft_grad)
                    shampoo_norm = torch.linalg.norm(shampoo_grad)

                    shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

                if group['weight_decay'] > 0.0:
                    if not group['decoupled_weight_decay']:
                        graft_grad.add_(p, alpha=group['weight_decay'])
                        shampoo_grad.add_(p, alpha=group['weight_decay'])
                    else:
                        graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
                        shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])

                state['momentum'].mul_(beta1).add_(shampoo_grad)
                graft_momentum = graft.update_momentum(grad, beta1)

                momentum_update = state['momentum'] if is_precondition_step else graft_momentum

                if group['nesterov']:
                    w: float = (1.0 - beta1) if group['moving_average_for_momentum'] else 1.0

                    wd_update = shampoo_grad if is_precondition_step else graft_grad
                    wd_update.mul_(w)

                    momentum_update.mul_(beta1).add_(wd_update)

                p.add_(momentum_update, alpha=-group['lr'])

        return loss
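
A usage sketch with the speed-related knobs discussed above (placeholder model and data; the values are illustrative, not recommendations):

import torch
from pytorch_optimizer import ScalableShampoo

model = torch.nn.Linear(128, 64)
optimizer = ScalableShampoo(
    model.parameters(),
    lr=1e-3,
    start_preconditioning_step=25,       # grafted (plain) updates until this step
    preconditioning_compute_steps=1000,  # recompute G^{-1/p} only this often (speed/quality trade-off)
    statistics_compute_steps=1,          # accumulate statistics every step
)

model(torch.randn(16, 128)).pow(2).mean().backward()
optimizer.step()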

SM3

Bases: Optimizer, BaseOptimizer

Memory-Efficient Adaptive Optimization.

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.1.
- momentum (float): coefficient used to scale prior updates before adding. This drastically increases memory usage if momentum > 0.0. This is ignored if the parameter's gradient is sparse. Default: 0.0.
- beta (float): coefficient used for exponential moving averages. Default: 0.0.
- eps (float): term added to the denominator to improve numerical stability. Default: 1e-30.
Source code in pytorch_optimizer/optimizer/sm3.py
class SM3(Optimizer, BaseOptimizer):
    r"""Memory-Efficient Adaptive Optimization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. coefficient used to scale prior updates before adding. This drastically increases
        memory usage if `momentum > 0.0`. This is ignored if the parameter's gradient is sparse.
    :param beta: float. coefficient used for exponential moving averages.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-1,
        momentum: float = 0.0,
        beta: float = 0.0,
        eps: float = 1e-30,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(momentum, 'momentum', 0.0, 1.0)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {'lr': lr, 'momentum': momentum, 'beta': beta, 'eps': eps}
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SM3'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['step'] = 0
                state['momentum_buffer'] = torch.zeros_like(p)

    @staticmethod
    def make_sparse(grad: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        if grad._indices().dim() == 0 or values.dim() == 0:
            return grad.new().resize_as_(grad)
        return grad.new(grad._indices(), values, grad.size())

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum, beta = group['momentum'], group['beta']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                shape = grad.shape
                rank: int = len(shape)

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['momentum_buffer'] = torch.zeros_like(p)

                    if grad.is_sparse:
                        state['accumulator_0'] = torch.zeros(shape[0], dtype=grad.dtype, device=grad.device)
                    elif rank == 0:
                        state['accumulator_0'] = torch.zeros_like(p)
                    else:
                        for i in range(rank):
                            state[f'accumulator_{i}'] = torch.zeros(
                                [1] * i + [shape[i]] + [1] * (rank - 1 - i), dtype=grad.dtype, device=grad.device
                            )

                state['step'] += 1

                if grad.is_sparse:
                    grad = grad.coalesce()

                    acc = state['accumulator_0']
                    update_values = torch.gather(acc, 0, grad._indices()[0])
                    if beta > 0.0:
                        update_values.mul_(beta)
                    update_values.addcmul_(grad._values(), grad._values(), value=1.0 - beta)

                    nu_max = reduce_max_except_dim(self.make_sparse(grad, update_values).to_dense(), 0).squeeze_()

                    if beta > 0.0:
                        torch.max(acc, nu_max, out=acc)
                    else:
                        acc.copy_(nu_max)

                    update_values.add_(group['eps']).rsqrt_().mul_(grad._values())

                    update = self.make_sparse(grad, update_values)
                else:
                    update = state['accumulator_0'].clone()
                    for i in range(1, rank):
                        update = torch.min(update, state[f'accumulator_{i}'])

                    if beta > 0.0:
                        update.mul_(beta)
                    update.addcmul_(grad, grad, value=1.0 - beta)

                    for i in range(rank):
                        acc = state[f'accumulator_{i}']
                        nu_max = reduce_max_except_dim(update, i)
                        if beta > 0.0:
                            torch.max(acc, nu_max, out=acc)
                        else:
                            acc.copy_(nu_max)

                    update.add_(group['eps']).rsqrt_().mul_(grad)

                    if momentum > 0.0:
                        m = state['momentum_buffer']
                        m.mul_(momentum).add_(update, alpha=1.0 - momentum)
                        update = m

                p.add_(update, alpha=-group['lr'])

        return loss
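
A small sketch (synthetic parameter and gradient) showing why SM3 is memory-efficient: for an (m, n) parameter it stores an (m, 1) and a (1, n) accumulator instead of a full (m, n) second-moment tensor.

import torch
from pytorch_optimizer import SM3

w = torch.nn.Parameter(torch.randn(1024, 512))
optimizer = SM3([w], lr=1e-1)

w.grad = torch.randn_like(w)
optimizer.step()

state = optimizer.state[w]
print(state['accumulator_0'].shape)  # torch.Size([1024, 1]), the row accumulator
print(state['accumulator_1'].shape)  # torch.Size([1, 512]), the column accumulator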

SophiaH

Bases: Optimizer, BaseOptimizer

Second-order Clipped Stochastic Optimization.

Requires `loss.backward(create_graph=True)` in order to calculate hessians.

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.06.
- betas (BETAS): coefficients used for computing running averages of gradient and the squared hessian trace. Default: (0.96, 0.99).
- weight_decay (float): weight decay (L2 penalty). Default: 0.0.
- weight_decouple (bool): the optimizer uses decoupled weight decay as in AdamW. Default: True.
- fixed_decay (bool): fix weight decay. Default: False.
- p (float): clip effective (applied) gradient (p). Default: 0.01.
- update_period (int): number of steps after which to apply hessian approximation. Default: 10.
- num_samples (int): times to sample z for the approximation of the hessian trace. Default: 1.
- hessian_distribution (HUTCHINSON_G): type of distribution to initialize hessian. Default: 'gaussian'.
- eps (float): term added to the denominator to improve numerical stability. Default: 1e-12.
Source code in pytorch_optimizer/optimizer/sophia.py
class SophiaH(Optimizer, BaseOptimizer):
    r"""Second-order Clipped Stochastic Optimization.

        Requires `loss.backward(create_graph=True)` in order to calculate hessians.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param p: float. clip effective (applied) gradient (p).
    :param update_period: int. number of steps after which to apply hessian approximation.
    :param num_samples: int. times to sample `z` for the approximation of the hessian trace.
    :param hessian_distribution: HUTCHINSON_G. type of distribution to initialize hessian.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 6e-2,
        betas: BETAS = (0.96, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        p: float = 1e-2,
        update_period: int = 10,
        num_samples: int = 1,
        hessian_distribution: HUTCHINSON_G = 'gaussian',
        eps: float = 1e-12,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(p, 'p (gradient clip)')
        self.validate_step(update_period, 'update_period')
        self.validate_positive(num_samples, 'num_samples')
        self.validate_options(hessian_distribution, 'hessian_distribution', ['gaussian', 'rademacher'])
        self.validate_non_negative(eps, 'eps')

        self.update_period = update_period
        self.num_samples = num_samples
        self.distribution = hessian_distribution

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'p': p,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SophiaH'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]
                state['momentum'] = torch.zeros_like(p)
                state['hessian_moment'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None, hessian: Optional[List[torch.Tensor]] = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        step: int = self.param_groups[0].get('step', 1)

        if hessian is not None:
            self.set_hessian(self.param_groups, self.state, hessian)
        elif step % self.update_period == 0:
            self.zero_hessian(self.param_groups, self.state)
            self.compute_hutchinson_hessian(
                param_groups=self.param_groups,
                state=self.state,
                num_samples=self.num_samples,
                distribution=self.distribution,
            )

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['momentum'] = torch.zeros_like(p)
                    state['hessian_moment'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                momentum, hessian_moment = state['momentum'], state['hessian_moment']
                momentum.mul_(beta1).add_(grad, alpha=1.0 - beta1)

                if 'hessian' in state and (group['step'] % self.update_period == 0 or hessian is not None):
                    hessian_moment.mul_(beta2).add_(state['hessian'], alpha=1.0 - beta2)

                update = (momentum / torch.clip(hessian_moment, min=group['eps'])).clamp_(-group['p'], group['p'])
                p.add_(update, alpha=-group['lr'])

        return loss
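
A minimal training-loop sketch (placeholder model and data). As noted above, the backward pass must be run with create_graph=True so the Hutchinson Hessian estimate can be computed every update_period steps:

import torch
from pytorch_optimizer import SophiaH

model = torch.nn.Linear(10, 1)
optimizer = SophiaH(model.parameters(), lr=6e-2, update_period=10)

for _ in range(20):
    optimizer.zero_grad()
    loss = model(torch.randn(32, 10)).pow(2).mean()
    loss.backward(create_graph=True)  # keep the graph so Hessian-vector products can be taken
    optimizer.step()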

SRMM

Bases: Optimizer, BaseOptimizer

Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates.

Parameters:

- params (PARAMETERS): iterable of parameters to optimize or dicts defining parameter groups. Required.
- lr (float): learning rate. Default: 0.01.
- beta (float): adaptivity weight. Default: 0.5.
- memory_length (Optional[int]): internal memory length for moving average. None for no refreshing. Default: 100.
Source code in pytorch_optimizer/optimizer/srmm.py
class SRMM(Optimizer, BaseOptimizer):
    """Stochastic regularized majorization-minimization with weakly convex and multi-convex surrogates.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. adaptivity weight.
    :param memory_length: Optional[int]. internal memory length for moving average. None for no refreshing.
    """

    def __init__(self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.5, memory_length: Optional[int] = 100):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[]')

        defaults: DEFAULTS = {'lr': lr, 'beta': beta, 'memory_length': memory_length}
        super().__init__(params, defaults)

        self.base_lrs: List[float] = [group['lr'] for group in self.param_groups]

    def __str__(self) -> str:
        return 'SRMM'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['mov_avg_grad'] = torch.zeros_like(p)
                state['mov_avg_param'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            w_t: float = (
                (group['step'] % (group['memory_length'] if group['memory_length'] is not None else 1)) + 1
            ) ** -group['beta']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]
                if len(state) == 0:
                    state['mov_avg_grad'] = torch.zeros_like(p)
                    state['mov_avg_param'] = torch.zeros_like(p)

                mov_avg_grad, mov_avg_param = state['mov_avg_grad'], state['mov_avg_param']

                mov_avg_grad.mul_(1.0 - w_t).add_(grad, alpha=w_t)
                mov_avg_param.mul_(1.0 - w_t).add_(p, alpha=w_t)

                mov_avg_param.add_(mov_avg_grad, alpha=-group['lr'])

                p.copy_(mov_avg_param)

        return loss
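
A minimal usage sketch (assuming SRMM is importable from the package root like the other optimizers in this library). Note from the step() body above that each update blends both the gradient and the parameter into moving averages and then copies the averaged parameter back, so the iterates themselves are smoothed.

import torch
import torch.nn.functional as F
from pytorch_optimizer import SRMM  # assumed import path

model = torch.nn.Linear(10, 1)
optimizer = SRMM(model.parameters(), lr=0.01, beta=0.5, memory_length=100)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()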

SWATS

Bases: Optimizer, BaseOptimizer

Improving Generalization Performance by Switching from Adam to SGD.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
betas (BETAS, default (0.9, 0.999)): coefficients used for computing running averages of gradient and the squared hessian trace.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default False): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
ams_bound (bool, default False): whether to use the ams_bound variant of this algorithm from the paper.
nesterov (bool, default False): enables Nesterov momentum.
r (float, default 0.95): EMA factor. a value between 0.9 and 0.99 is preferred.
adanorm (bool, default False): whether to use the AdaNorm variant.
adam_debias (bool, default False): only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 1e-06): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/swats.py
class SWATS(Optimizer, BaseOptimizer):
    r"""Improving Generalization Performance by Switching from Adam to SGD.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ams_bound: bool. whether to use the ams_bound variant of this algorithm from the paper.
    :param nesterov: bool. enables Nesterov momentum.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        ams_bound: bool = False,
        nesterov: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-6,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'ams_bound': ams_bound,
            'nesterov': nesterov,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'phase': 'adam',
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'SWATS'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)
                state['exp_avg2'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                if group['ams_bound']:
                    state['max_exp_avg_sq'] = torch.zeros_like(p)
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2: float = 1.0 - beta2 ** group['step']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg2'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    if group['ams_bound']:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                if group['phase'] == 'sgd':
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(grad)

                    buf = state['momentum_buffer']
                    buf.mul_(beta1).add_(grad)

                    update = buf.clone()
                    update.mul_(1.0 - beta1)

                    if group['nesterov']:
                        update.add_(buf, alpha=beta1)

                    p.add_(update, alpha=-group['lr'])

                    continue

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

                de_nom = self.apply_ams_bound(
                    ams_bound=group['ams_bound'],
                    exp_avg_sq=exp_avg_sq,
                    max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                    eps=group['eps'],
                )

                step_size: float = self.apply_adam_debias(
                    adam_debias=group['adam_debias'],
                    step_size=group['lr'] * math.sqrt(bias_correction2),
                    bias_correction1=bias_correction1,
                )

                perturb = exp_avg.clone()
                perturb.div_(de_nom).mul_(-step_size)

                p.add_(perturb)

                perturb_view = perturb.view(-1)
                pg = perturb_view.dot(grad.view(-1))

                if pg != 0:
                    scaling = perturb_view.dot(perturb_view).div_(-pg)

                    exp_avg2 = state['exp_avg2']
                    exp_avg2.mul_(beta2).add_(scaling, alpha=1.0 - beta2)

                    corrected_exp_avg = exp_avg2 / bias_correction2

                    if (
                        group['step'] > 1
                        and corrected_exp_avg > 0.0
                        and corrected_exp_avg.allclose(scaling, rtol=group['eps'])
                    ):
                        group['phase'] = 'sgd'
                        group['lr'] = corrected_exp_avg.item()

        return loss
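
A minimal sketch of how the Adam-to-SGD switch can be observed from outside (assuming SWATS is importable from the package root). As the step() body above shows, the 'phase' flag lives on each param group and flips to 'sgd' once the projected step-size estimate stabilizes, at which point group['lr'] has been overwritten with that estimate.

import torch
import torch.nn.functional as F
from pytorch_optimizer import SWATS  # assumed import path

model = torch.nn.Linear(10, 1)
optimizer = SWATS(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

x, y = torch.randn(128, 10), torch.randn(128, 1)
for step in range(200):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    group = optimizer.param_groups[0]
    if group['phase'] == 'sgd':
        # the switch may or may not trigger on a toy problem; this only shows where to look
        print(f"switched to SGD at step {step} with lr={group['lr']:.4g}")
        break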

Tiger

Bases: Optimizer, BaseOptimizer

A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.001): learning rate.
beta (float, default 0.965): coefficient used for computing the running average of the gradient.
weight_decay (float, default 0.01): weight decay (L2 penalty).
weight_decouple (bool, default True): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
Source code in pytorch_optimizer/optimizer/tiger.py
class Tiger(Optimizer, BaseOptimizer):
    r"""A Tight-fisted Optimizer, an optimizer that is extremely budget-conscious.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param beta: float. coefficient used for computing the running average of the gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        beta: float = 0.965,
        weight_decay: float = 0.01,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
    ):
        self.validate_learning_rate(lr)
        self.validate_range(beta, 'beta', 0.0, 1.0, range_type='[)')
        self.validate_non_negative(weight_decay, 'weight_decay')

        defaults: DEFAULTS = {
            'lr': lr,
            'beta': beta,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
        }

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Tiger'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta = group['beta']
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                self.apply_weight_decay(
                    p=p,
                    grad=grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                exp_avg = state['exp_avg']
                exp_avg.mul_(beta).add_(grad, alpha=1.0 - beta)

                p.add_(torch.sign(exp_avg), alpha=-group['lr'])

        return loss
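
A minimal usage sketch (assuming Tiger is importable from the package root). Because the update is the sign of the gradient EMA, every parameter moves by exactly lr in magnitude per step, which is why Tiger keeps only a single state tensor per parameter; weight_decay is set to 0.0 here so that property is visible.

import torch
import torch.nn.functional as F
from pytorch_optimizer import Tiger  # assumed import path

model = torch.nn.Linear(10, 1)
optimizer = Tiger(model.parameters(), lr=1e-3, beta=0.965, weight_decay=0.0)

x, y = torch.randn(32, 10), torch.randn(32, 1)
before = model.weight.detach().clone()

optimizer.zero_grad()
F.mse_loss(model(x), y).backward()
optimizer.step()

# each coordinate moves by exactly lr in magnitude (up to floating point),
# since the update is sign(exp_avg) scaled by -lr
print((model.weight.detach() - before).abs().unique())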

WSAM

Bases: Optimizer, BaseOptimizer

Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term.

Parameters:

Name Type Description Default
model Union[Module, DistributedDataParallel]

Union[torch.nn.Module, torch.nn.DataParallel]. the model instance. DDP model is recommended to make model.no_sync to work.

required
params PARAMETERS

PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.

required
base_optimizer OPTIMIZER

Optimizer. base optimizer.

required
rho float

float. size of the neighborhood for computing the max loss.

0.05
gamma float

float. weighted factor gamma / (1 - gamma) of the sharpness term. 0.8 ~ 0.95 is the optimal.

0.9
adaptive bool

bool. element-wise adaptive SAM.

False
decouple bool

bool. whether to perform a decoupled sharpness regularization.

True
max_norm Optional[float]

Optional[float]. max norm of the gradients.

None
eps float

float. term added to the denominator of WSAM to improve numerical stability.

1e-12
kwargs

Dict. parameters for optimizer.

{}
Source code in pytorch_optimizer/optimizer/sam.py
class WSAM(Optimizer, BaseOptimizer):
    r"""Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term.

    :param model: Union[torch.nn.Module, DistributedDataParallel]. the model instance. a DDP-wrapped model is
        recommended so that `model.no_sync` works.
    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param base_optimizer: Optimizer. base optimizer.
    :param rho: float. size of the neighborhood for computing the max loss.
    :param gamma: float. weighted factor gamma / (1 - gamma) of the sharpness term. 0.8 ~ 0.95 is the optimal.
    :param adaptive: bool. element-wise adaptive SAM.
    :param decouple: bool. whether to perform a decoupled sharpness regularization.
    :param max_norm: Optional[float]. max norm of the gradients.
    :param eps: float. term added to the denominator of WSAM to improve numerical stability.
    :param kwargs: Dict. parameters for optimizer.
    """

    def __init__(
        self,
        model: Union[nn.Module, DistributedDataParallel],
        params: PARAMETERS,
        base_optimizer: OPTIMIZER,
        rho: float = 0.05,
        gamma: float = 0.9,
        adaptive: bool = False,
        decouple: bool = True,
        max_norm: Optional[float] = None,
        eps: float = 1e-12,
        **kwargs,
    ):
        self.validate_non_negative(rho, 'rho')

        self.model = model
        self.decouple = decouple
        self.max_norm = max_norm

        alpha: float = gamma / (1.0 - gamma)

        defaults: DEFAULTS = {'rho': rho, 'alpha': alpha, 'adaptive': adaptive, 'sam_eps': eps}
        defaults.update(kwargs)
        super().__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    def __str__(self) -> str:
        return 'WSAM'

    @torch.no_grad()
    def reset(self):
        pass

    @torch.no_grad()
    def first_step(self, zero_grad: bool = False):
        grad_norm = self.grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + group['sam_eps'])

            for p in group['params']:
                if p.grad is None:
                    continue

                e_w = (torch.pow(p, 2) if group['adaptive'] else 1.0) * p.grad * scale.to(p)

                # climb to the local maximum "w + e(w)"
                p.add_(e_w)

                self.state[p]['e_w'] = e_w

                if is_initialized():  # pragma: no cover
                    all_reduce(p.grad, op=ReduceOp.AVG)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                self.state[p]['grad'] = p.grad.clone()

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad: bool = False):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if is_initialized():  # pragma: no cover
                    all_reduce(p.grad, ReduceOp.AVG)

                # get back to "w" from "w + e(w)"
                p.add_(self.state[p]['e_w'], alpha=-1.0)

        if self.max_norm is not None:
            clip_grad_norm_(self.model.parameters(), self.max_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                if not self.decouple:
                    p.grad.mul_(group['alpha']).add_(self.state[p]['grad'], alpha=1.0 - group['alpha'])
                else:
                    self.state[p]['sharpness'] = p.grad.clone() - self.state[p]['grad']
                    p.grad.mul_(0.0).add_(self.state[p]['grad'], alpha=1.0)

        # do the actual "sharpness-aware" update
        self.base_optimizer.step()

        if self.decouple:
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    p.add_(self.state[p]['sharpness'], alpha=-group['lr'] * group['alpha'])

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure: CLOSURE = None):
        if closure is None:
            raise NoClosureError(str(self))

        closure = torch.enable_grad()(closure)

        enable_running_stats(self.model)
        loss = closure()
        self.first_step(zero_grad=True)

        disable_running_stats(self.model)
        closure()
        self.second_step()

        return loss

    def grad_norm(self) -> torch.Tensor:
        shared_device = self.param_groups[0]['params'][0].device
        return torch.norm(
            torch.stack(
                [
                    ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad).norm(p=2).to(shared_device)
                    for group in self.param_groups
                    for p in group['params']
                    if p.grad is not None
                ]
            ),
            p=2,
        )

    def load_state_dict(self, state_dict: Dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
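
A minimal closure-based sketch (assuming WSAM is importable from the package root; the SGD hyper-parameters are illustrative, not recommendations). Keyword arguments such as lr and momentum end up in the param-group defaults and are forwarded to the base optimizer's constructor, and step() must receive a closure because it re-evaluates the loss at the perturbed point "w + e(w)".

import torch
import torch.nn.functional as F
from pytorch_optimizer import WSAM  # assumed import path

model = torch.nn.Linear(10, 1)
optimizer = WSAM(
    model,
    model.parameters(),
    base_optimizer=torch.optim.SGD,
    rho=0.05,
    gamma=0.9,
    lr=0.1,        # forwarded to the base SGD optimizer
    momentum=0.9,  # forwarded to the base SGD optimizer
)

x, y = torch.randn(64, 10), torch.randn(64, 1)

def closure():
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(10):
    loss = optimizer.step(closure)  # runs first_step and second_step internally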

Yogi

Bases: Optimizer, BaseOptimizer

Adaptive Methods for Nonconvex Optimization.

Parameters:

params (PARAMETERS, required): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, default 0.01): learning rate.
betas (BETAS, default (0.9, 0.999)): coefficients used for computing running averages of gradient and the squared hessian trace.
initial_accumulator (float, default 1e-06): initial values for first and second moments.
weight_decay (float, default 0.0): weight decay (L2 penalty).
weight_decouple (bool, default True): the optimizer uses decoupled weight decay as in AdamW.
fixed_decay (bool, default False): fix weight decay.
r (float, default 0.95): EMA factor. a value between 0.9 and 0.99 is preferred.
adanorm (bool, default False): whether to use the AdaNorm variant.
adam_debias (bool, default False): only correct the denominator to avoid inflating step sizes early in training.
eps (float, default 0.001): term added to the denominator to improve numerical stability.
Source code in pytorch_optimizer/optimizer/yogi.py
class Yogi(Optimizer, BaseOptimizer):
    r"""Decoupled Weight Decay Regularization.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
    :param initial_accumulator: float. initial values for first and second moments.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-2,
        betas: BETAS = (0.9, 0.999),
        initial_accumulator: float = 1e-6,
        weight_decay: float = 0.0,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        r: float = 0.95,
        adanorm: bool = False,
        adam_debias: bool = False,
        eps: float = 1e-3,
    ):
        self.validate_learning_rate(lr)
        self.validate_betas(betas)
        self.validate_non_negative(weight_decay, 'weight_decay')
        self.validate_non_negative(eps, 'eps')

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'fixed_decay': fixed_decay,
            'initial_accumulator': initial_accumulator,
            'adanorm': adanorm,
            'adam_debias': adam_debias,
            'eps': eps,
        }
        if adanorm:
            defaults.update({'r': r})

        super().__init__(params, defaults)

    def __str__(self) -> str:
        return 'Yogi'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['exp_avg'] = torch.full_like(p, fill_value=group['initial_accumulator'])
                state['exp_avg_sq'] = torch.full_like(p, fill_value=group['initial_accumulator'])
                if group['adanorm']:
                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']

            bias_correction1: float = 1.0 - beta1 ** group['step']
            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['exp_avg'] = torch.full_like(p, fill_value=group['initial_accumulator'])
                    state['exp_avg_sq'] = torch.full_like(p, fill_value=group['initial_accumulator'])
                    if group['adanorm']:
                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)

                self.apply_weight_decay(
                    p=p,
                    grad=p.grad,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
                    weight_decouple=group['weight_decouple'],
                    fixed_decay=group['fixed_decay'],
                )

                s_grad = self.get_adanorm_gradient(
                    grad=grad,
                    adanorm=group['adanorm'],
                    exp_grad_norm=state.get('exp_grad_norm', None),
                    r=group.get('r', None),
                )

                grad_p2 = grad.mul(grad)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
                exp_avg_sq.addcmul_((exp_avg_sq - grad_p2).sign_(), grad_p2, value=-(1.0 - beta2))

                de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

                step_size: float = self.apply_adam_debias(
                    adam_debias=group['adam_debias'], step_size=group['lr'], bias_correction1=bias_correction1
                )

                p.addcdiv_(exp_avg, de_nom, value=-step_size)

        return loss
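
A minimal usage sketch (assuming Yogi is importable from the package root). As the step() body above shows, the second moment is updated additively with a sign term, v <- v - (1 - beta2) * sign(v - g^2) * g^2, rather than multiplicatively as in Adam, so initial_accumulator sets the floor it grows or shrinks from.

import torch
import torch.nn.functional as F
from pytorch_optimizer import Yogi  # assumed import path

model = torch.nn.Linear(10, 1)
optimizer = Yogi(model.parameters(), lr=1e-2, betas=(0.9, 0.999), initial_accumulator=1e-6)

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()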