Skip to content

Base

BaseOptimizer

Bases: ABC, Optimizer

Base optimizer class. Provides common functionalities for the optimizers.

Source code in pytorch_optimizer/base/optimizer.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
class BaseOptimizer(ABC, Optimizer):
    r"""Base optimizer class. Provides common functionalities for the optimizers.

    Subclasses must implement ``init_group`` (and normally ``step``). The static helpers implement math shared by
    many optimizers (weight decay, Adam debiasing, rectification, AMSBound, AdaNorm, Cautious masking, Hutchinson
    hessian estimation) together with hyper-parameter validation.
    """

    def __init__(self, params: PARAMETERS, defaults: DEFAULTS) -> None:
        super().__init__(params, defaults)

    @staticmethod
    def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimizer:
        r"""Build torch.optim.Optimizer class.

        :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. an optimizer instance (returned unchanged) or an optimizer
            class (instantiated with ``kwargs``).
        :param kwargs: keyword arguments for the optimizer class. ``params`` is required when ``optimizer`` is a class.
        :raises ValueError: if ``optimizer`` is a class and ``params`` is missing from ``kwargs``.
        """
        if isinstance(optimizer, Optimizer):
            return optimizer

        if 'params' in kwargs:
            params = kwargs.pop('params')
            return optimizer(params, **kwargs)

        # Reaching here means `optimizer` is a class, not an instance, so it cannot be built without `params`.
        raise ValueError('need to pass `params` when you pass the `torch.optim.Optimizer` class.')

    @staticmethod
    @torch.no_grad()
    def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
        r"""Set hessian to state from external source. Generally useful when using functorch as a base.

        Example:
        -------
            Here's an example::

                # Hutchinson's Estimator using HVP
                noise = tree_map(lambda v: torch.randn_like(v), params)
                loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
                hessian_diag_est  = tree_map(lambda a, b: a * b, hvp_est, noise)

                optimizer.set_hessian(hessian_diag_est)
                # OR
                optimizer.step(hessian=hessian_diag_est)

        :param param_groups: PARAMETERS. parameter groups.
        :param state: STATE. optimizer state.
        :param hessian: List[torch.Tensor]. sequence of hessian to set, one per parameter, in the order the
            parameters appear across ``param_groups``.
        :raises ValueError: if the number of hessians differs from the number of parameters, or if a hessian's
            shape does not match its parameter.
        """
        params: List[torch.Tensor] = [p for group in param_groups for p in group['params']]

        # Check the count up front so nothing is written on a mismatch and extra hessians are not silently dropped.
        if len(hessian) != len(params):
            raise ValueError(
                f'[-] the number of parameters and hessians does not match. {len(params)} vs {len(hessian)}'
            )

        for p, h in zip(params, hessian):
            if p.size() != h.size():
                raise ValueError(f'[-] the shape of parameter and hessian does not match. {p.size()} vs {h.size()}')

            state[p]['hessian'] = h

    @staticmethod
    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True) -> None:
        r"""Zero-out hessian.

        Only parameters that require grad and have a dense (non-sparse) gradient are touched. A missing
        ``hessian`` entry is created as zeros. An existing one is zeroed in place when ``pre_zero`` is set.

        :param param_groups: PARAMETERS. parameter groups.
        :param state: STATE. optimizer state.
        :param pre_zero: bool. zero-out hessian before computing the hessian.
        """
        for group in param_groups:
            for p in group['params']:
                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
                    if 'hessian' not in state[p]:
                        state[p]['hessian'] = torch.zeros_like(p)
                    elif pre_zero:
                        state[p]['hessian'].zero_()

    @staticmethod
    @torch.no_grad()
    def compute_hutchinson_hessian(
        param_groups: PARAMETERS,
        state: STATE,
        num_samples: int = 1,
        alpha: float = 1.0,
        distribution: HUTCHINSON_G = 'gaussian',
    ) -> None:
        r"""Hutchinson's approximate hessian, added to the state under key `hessian`.

        The gradients must be differentiable w.r.t. the parameters (i.e. carry an autograd graph). The
        ``hessian`` state entries must already exist, e.g. via ``zero_hessian``.

        :param param_groups: PARAMETERS. parameter groups.
        :param state: STATE. optimizer state.
        :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
        :param alpha: float. scale applied to the averaged estimate before accumulation.
        :param distribution: HUTCHINSON_G. type of distribution. 'gaussian' or 'rademacher'.
        :raises NotImplementedError: for any other distribution.
        """
        if distribution not in ('gaussian', 'rademacher'):
            raise NotImplementedError(f'hessian with distribution {distribution} is not implemented.')

        params: List[torch.Tensor] = [
            p
            for group in param_groups
            for p in group['params']
            if p.requires_grad and p.grad is not None and not p.grad.is_sparse
        ]
        if len(params) == 0:
            return

        grads = [p.grad for p in params]

        for i in range(num_samples):
            if distribution == 'rademacher':
                # `randint`'s upper bound is exclusive: draw from {0, 1}, then map to {-1, +1}.
                zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]
            else:
                zs = [torch.randn_like(p) for p in params]

            # Hessian-vector products H @ z. The graph is kept alive for every sample except the last one.
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                # z * (H @ z) estimates diag(H). Average it over the samples.
                state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)

    @staticmethod
    def apply_weight_decay(
        p: torch.Tensor,
        grad: Optional[torch.Tensor],
        lr: float,
        weight_decay: float,
        weight_decouple: bool,
        fixed_decay: bool,
        ratio: Optional[float] = None,
    ) -> None:
        r"""Apply weight decay.

        Decoupled mode scales ``p`` in place. Coupled mode adds ``weight_decay * p`` to ``grad`` in place, and only
        when ``weight_decay > 0`` and ``grad`` is given.

        :param p: torch.Tensor. parameter.
        :param grad: torch.Tensor. gradient.
        :param lr: float. learning rate.
        :param weight_decay: float. weight decay (L2 penalty).
        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
        :param fixed_decay: bool. fix weight decay (do not scale the decoupled decay by ``lr``).
        :param ratio: Optional[float]. scale weight decay.
        """
        if weight_decouple:
            p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
        elif weight_decay > 0.0 and grad is not None:
            grad.add_(p, alpha=weight_decay)

    @staticmethod
    def apply_ams_bound(
        ams_bound: bool, exp_avg_sq: torch.Tensor, max_exp_avg_sq: Optional[torch.Tensor], eps: float
    ) -> torch.Tensor:
        r"""Apply AMSBound variant.

        With ``ams_bound``, ``max_exp_avg_sq`` (which must then be a tensor) is updated in place to the element-wise
        maximum with ``exp_avg_sq``. Note that ``eps`` is added both before and after the square root.

        :param ams_bound: bool. whether to apply AMSBound.
        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
        :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
        :param eps: float. epsilon.
        :returns: a new tensor ``sqrt(v + eps) + eps``, where ``v`` is the (max-)second moment.
        """
        if ams_bound:
            if torch.is_complex(max_exp_avg_sq):
                max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)

            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            de_nom = max_exp_avg_sq.add(eps)
        else:
            de_nom = exp_avg_sq.add(eps)

        return de_nom.sqrt_().add_(eps)

    @staticmethod
    def debias(beta: float, step: int) -> float:
        r"""Adam-style debias correction. Returns `1.0 - beta ** step`.

        :param beta: float. beta.
        :param step: int. number of step.
        """
        return 1.0 - math.pow(beta, step)  # fmt: skip

    @staticmethod
    def debias_beta(beta: float, step: int) -> float:
        r"""Apply the Adam-style debias correction into beta.

        Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`

        :param beta: float. beta.
        :param step: int. number of step.
        """
        beta_n: float = math.pow(beta, step)
        return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip

    @staticmethod
    def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
        r"""Apply AdamD variant.

        :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
        :param step_size: float. step size.
        :param bias_correction1: float. bias_correction.
        """
        return step_size if adam_debias else step_size / bias_correction1

    @staticmethod
    def get_rectify_step_size(
        is_rectify: bool,
        step: int,
        lr: float,
        beta2: float,
        n_sma_threshold: int,
        degenerated_to_sgd: bool,
    ) -> Tuple[float, float]:
        r"""Get step size for rectify optimizer.

        :param is_rectify: bool. whether to apply rectify-variant.
        :param step: int. number of steps.
        :param lr: float. learning rate.
        :param beta2: float. beta2.
        :param n_sma_threshold: int. SMA threshold.
        :param degenerated_to_sgd: bool. degenerated to SGD.
        :returns: ``(step_size, n_sma)``. Without rectification this is ``(lr, 0.0)``. Below the threshold the
            step size is ``lr`` if ``degenerated_to_sgd`` else ``-lr``.
        """
        step_size: float = lr
        n_sma: float = 0.0

        if is_rectify:
            n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
            beta2_t: float = beta2 ** step  # fmt: skip
            n_sma = n_sma_max - 2 * step * beta2_t / (1.0 - beta2_t)

            if n_sma >= n_sma_threshold:
                rt = math.sqrt(
                    (1.0 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
                )
            elif degenerated_to_sgd:
                rt = 1.0
            else:
                rt = -1.0

            step_size *= rt

        return step_size, n_sma

    @staticmethod
    def get_adanorm_gradient(
        grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
    ) -> torch.Tensor:
        r"""Get AdaNorm gradient.

        ``exp_grad_norm`` is updated in place with the EMA of the gradient norm. When that EMA exceeds the current
        norm, the gradient is rescaled (into a new tensor) to the EMA norm. Otherwise ``grad`` is returned as-is.

        :param grad: torch.Tensor. gradient.
        :param adanorm: bool. whether to use the AdaNorm variant.
        :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
        :param r: Optional[float]. EMA factor. between 0.9 ~ 0.99 is preferred.
        """
        if not adanorm or exp_grad_norm is None:
            return grad

        if r is None:
            r = 0.95

        grad_norm = torch.linalg.norm(grad)

        # In-place, so the running norm persists across steps.
        exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

        return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad

    @staticmethod
    def get_rms(x: torch.Tensor) -> torch.Tensor:
        r"""Get RMS. Returns the 0-dim tensor ``||x||_2 / sqrt(numel(x))``."""
        return x.norm(2) / math.sqrt(x.numel())

    @staticmethod
    def approximate_sq_grad(
        exp_avg_sq_row: torch.Tensor,
        exp_avg_sq_col: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        r"""Get approximation of EMA of squared gradient.

        Writes ``rsqrt(row / mean(row))[..., :, None] * rsqrt(col)[..., None, :]`` into ``output``. This is the
        reciprocal square root of the factored second-moment estimate.
        """
        r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)

    @staticmethod
    def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
        r"""Apply the Cautious Optimizer feature.

        Zeroes update entries whose sign disagrees with the gradient. The rest are rescaled by
        ``numel / (kept + 1)``.

        :param update: torch.Tensor. update. it'll be masked in in-place manner.
        :param grad: torch.Tensor. gradient.
        """
        mask = (update * grad > 0).to(grad.dtype)
        mask.mul_(mask.numel() / (mask.sum() + 1))
        update.mul_(mask)

    @staticmethod
    def get_stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
        r"""Get StableAdamW RMS.

        :param grad: torch.Tensor. gradient.
        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
        :param eps: float. epsilon.
        :returns: ``max(1, sqrt(mean(grad^2 / max(exp_avg_sq, eps))))`` as a Python float.
        """
        return grad.pow(2).div_(exp_avg_sq.clip(min=eps)).mean().sqrt_().clip_(min=1.0).item()

    @staticmethod
    def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
        r"""Raise ValueError if ``x`` is outside the interval given by ``range_type`` ('[)', '[]', '(]', '()')."""
        if range_type == '[)' and not low <= x < high:
            raise ValueError(f'[-] {name} must be in the range [{low}, {high})')
        if range_type == '[]' and not low <= x <= high:
            raise ValueError(f'[-] {name} must be in the range [{low}, {high}]')
        if range_type == '(]' and not low < x <= high:
            raise ValueError(f'[-] {name} must be in the range ({low}, {high}]')
        if range_type == '()' and not low < x < high:
            raise ValueError(f'[-] {name} must be in the range ({low}, {high})')

    @staticmethod
    def validate_non_negative(x: Optional[float], name: str) -> None:
        r"""Raise ValueError if ``x`` is given and negative."""
        if x is not None and x < 0.0:
            raise ValueError(f'[-] {name} must be non-negative')

    @staticmethod
    def validate_non_positive(x: Optional[float], name: str) -> None:
        r"""Raise ValueError if ``x`` is given and positive."""
        if x is not None and x > 0.0:
            raise ValueError(f'[-] {name} must be non-positive')

    @staticmethod
    def validate_positive(x: Union[float, int], name: str) -> None:
        r"""Raise ValueError if ``x`` is not strictly positive."""
        if x <= 0:
            raise ValueError(f'[-] {name} must be positive')

    @staticmethod
    def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper') -> None:
        r"""Raise ValueError if ``constant`` exceeds an 'upper' or falls below a 'lower' ``boundary``."""
        if bound_type == 'upper' and constant > boundary:
            raise ValueError(f'[-] constant {constant} must be in a range of (-inf, {boundary}]')
        if bound_type == 'lower' and constant < boundary:
            raise ValueError(f'[-] constant {constant} must be in a range of [{boundary}, inf)')

    @staticmethod
    def validate_step(step: int, step_type: str) -> None:
        r"""Raise NegativeStepError if ``step`` is smaller than 1."""
        if step < 1:
            raise NegativeStepError(step, step_type=step_type)

    @staticmethod
    def validate_options(x: str, name: str, options: List[str]) -> None:
        r"""Raise ValueError if ``x`` is not one of ``options``."""
        if x not in options:
            opts: str = ' or '.join([f"'{option}'" for option in options]).strip()
            raise ValueError(f'[-] {name} {x} must be one of ({opts})')

    @staticmethod
    def validate_learning_rate(learning_rate: Optional[float]) -> None:
        r"""Raise NegativeLRError if ``learning_rate`` is given and negative."""
        if learning_rate is not None and learning_rate < 0.0:
            raise NegativeLRError(learning_rate)

    @staticmethod
    def validate_mod(x: int, y: int) -> None:
        r"""Raise ValueError if ``x`` is not divisible by ``y``."""
        if x % y != 0:
            raise ValueError(f'[-] {x} must be divisible by {y}')

    def validate_betas(self, betas: BETAS, beta_range_type: str = '[)', beta3_range_type: str = '[]') -> None:
        r"""Validate beta1 (if not None), beta2 and, when present and not None, beta3 against [0, 1]."""
        if betas[0] is not None:
            self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type=beta_range_type)

        self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type=beta_range_type)

        if len(betas) < 3:
            return

        if betas[2] is not None:
            self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type=beta3_range_type)

    def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
        r"""Validate a scalar nu, or a (nu1, nu2) pair, against [0, 1]."""
        # Accept int scalars (e.g. 0 or 1) as well as floats; previously they fell through to tuple indexing.
        if isinstance(nus, (int, float)):
            self.validate_range(nus, 'nu', 0.0, 1.0, range_type='[]')
        else:
            self.validate_range(nus[0], 'nu1', 0.0, 1.0, range_type='[]')
            self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]')

    @abstractmethod
    def init_group(self, group: GROUP, **kwargs) -> None:  # pragma: no cover
        r"""Initialize a parameter group of the optimizer. Must be implemented by subclasses."""
        return

    @staticmethod
    def view_as_real(param, *state_and_grads) -> tuple:
        r"""View imaginary tensors as real tensors.

        Only when ``param`` is complex: ``param`` and every complex entry of ``state_and_grads`` are replaced by
        their ``torch.view_as_real`` views. ``None`` and real entries pass through unchanged.

        :returns: ``(param, *state_and_grads)``.
        """
        if torch.is_complex(param):
            param = torch.view_as_real(param)
            state_and_grads = tuple(
                torch.view_as_real(s) if s is not None and torch.is_complex(s) else s for s in state_and_grads
            )

        return param, *state_and_grads

    @staticmethod
    def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None:
        r"""Maximize the objective with respect to the params, instead of minimizing (negates ``grad`` in place)."""
        if maximize:
            grad.neg_()

    def step(self, closure: CLOSURE = None) -> LOSS:  # pragma: no cover
        r"""Perform a single optimization step. Must be implemented by subclasses."""
        raise NotImplementedError

apply_adam_debias(adam_debias, step_size, bias_correction1) staticmethod

Apply AdamD variant.

Parameters:

Name Type Description Default
adam_debias bool

bool. Only correct the denominator to avoid inflating step sizes early in training.

required
step_size float

float. step size.

required
bias_correction1 float

float. bias_correction.

required
Source code in pytorch_optimizer/base/optimizer.py
198
199
200
201
202
203
204
205
206
@staticmethod
def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
    r"""Apply AdamD variant.

    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
    :param step_size: float. step size.
    :param bias_correction1: float. bias_correction.
    """
    return step_size if adam_debias else step_size / bias_correction1

apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps) staticmethod

Apply AMSBound variant.

Parameters:

Name Type Description Default
ams_bound bool

bool. whether to apply AMSBound.

required
exp_avg_sq Tensor

torch.Tensor. exp_avg_sq.

required
max_exp_avg_sq Optional[Tensor]

Optional[torch.Tensor]. max_exp_avg_sq.

required
eps float

float. epsilon.

required
Source code in pytorch_optimizer/base/optimizer.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
@staticmethod
def apply_ams_bound(
    ams_bound: bool, exp_avg_sq: torch.Tensor, max_exp_avg_sq: Optional[torch.Tensor], eps: float
) -> torch.Tensor:
    r"""Apply AMSBound variant.

    :param ams_bound: bool. whether to apply AMSBound.
    :param exp_avg_sq: torch.Tensor. exp_avg_sq.
    :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
    :param eps: float. epsilon.
    """
    if ams_bound:
        if torch.is_complex(max_exp_avg_sq):
            max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)

        torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        de_nom = max_exp_avg_sq.add(eps)
    else:
        de_nom = exp_avg_sq.add(eps)

    return de_nom.sqrt_().add_(eps)

apply_cautious(update, grad) staticmethod

Apply the Cautious Optimizer feature.

Parameters:

Name Type Description Default
update Tensor

torch.Tensor. update. it'll be masked in in-place manner.

required
grad Tensor

torch.Tensor. gradient.

required
Source code in pytorch_optimizer/base/optimizer.py
286
287
288
289
290
291
292
293
294
295
@staticmethod
def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
    r"""Apply the Cautious Optimizer feature.

    :param update: torch.Tensor. update. it'll be masked in in-place manner.
    :param grad: torch.Tensor. gradient.
    """
    mask = (update * grad > 0).to(grad.dtype)
    mask.mul_(mask.numel() / (mask.sum() + 1))
    update.mul_(mask)

apply_weight_decay(p, grad, lr, weight_decay, weight_decouple, fixed_decay, ratio=None) staticmethod

Apply weight decay.

Parameters:

Name Type Description Default
p Tensor

torch.Tensor. parameter.

required
grad Optional[Tensor]

torch.Tensor. gradient.

required
lr float

float. learning rate.

required
weight_decay float

float. weight decay (L2 penalty).

required
weight_decouple bool

bool. the optimizer uses decoupled weight decay as in AdamW.

required
fixed_decay bool

bool. fix weight decay.

required
ratio Optional[float]

Optional[float]. scale weight decay.

None
Source code in pytorch_optimizer/base/optimizer.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
@staticmethod
def apply_weight_decay(
    p: torch.Tensor,
    grad: Optional[torch.Tensor],
    lr: float,
    weight_decay: float,
    weight_decouple: bool,
    fixed_decay: bool,
    ratio: Optional[float] = None,
) -> None:
    r"""Apply weight decay.

    :param p: torch.Tensor. parameter.
    :param grad: torch.Tensor. gradient.
    :param lr: float. learning rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ratio: Optional[float]. scale weight decay.
    """
    if weight_decouple:
        p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
    elif weight_decay > 0.0 and grad is not None:
        grad.add_(p, alpha=weight_decay)

approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output) staticmethod

Get approximation of EMA of squared gradient.

Source code in pytorch_optimizer/base/optimizer.py
275
276
277
278
279
280
281
282
283
284
@staticmethod
def approximate_sq_grad(
    exp_avg_sq_row: torch.Tensor,
    exp_avg_sq_col: torch.Tensor,
    output: torch.Tensor,
) -> None:
    r"""Get approximation of EMA of squared gradient."""
    r_factor: torch.Tensor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
    c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    torch.mul(r_factor, c_factor, out=output)

compute_hutchinson_hessian(param_groups, state, num_samples=1, alpha=1.0, distribution='gaussian') staticmethod

Hutchinson's approximate hessian, added to the state under key hessian.

Parameters:

Name Type Description Default
param_groups PARAMETERS

PARAMETERS. parameter groups.

required
state STATE

STATE. optimizer state.

required
num_samples int

int. number of times to sample z for the approximation of the hessian trace.

1
alpha float

float. alpha.

1.0
distribution HUTCHINSON_G

HUTCHINSON_G. type of distribution.

'gaussian'
Source code in pytorch_optimizer/base/optimizer.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
@staticmethod
@torch.no_grad()
def compute_hutchinson_hessian(
    param_groups: PARAMETERS,
    state: STATE,
    num_samples: int = 1,
    alpha: float = 1.0,
    distribution: HUTCHINSON_G = 'gaussian',
) -> None:
    r"""Hutchinson's approximate hessian, added to the state under key `hessian`.

    The gradients must be differentiable w.r.t. the parameters (i.e. carry an autograd graph). The
    ``hessian`` state entries must already exist, e.g. via ``zero_hessian``.

    :param param_groups: PARAMETERS. parameter groups.
    :param state: STATE. optimizer state.
    :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
    :param alpha: float. scale applied to the averaged estimate before accumulation.
    :param distribution: HUTCHINSON_G. type of distribution. 'gaussian' or 'rademacher'.
    :raises NotImplementedError: for any other distribution.
    """
    if distribution not in ('gaussian', 'rademacher'):
        raise NotImplementedError(f'hessian with distribution {distribution} is not implemented.')

    params: List[torch.Tensor] = [
        p
        for group in param_groups
        for p in group['params']
        if p.requires_grad and p.grad is not None and not p.grad.is_sparse
    ]
    if len(params) == 0:
        return

    grads = [p.grad for p in params]

    for i in range(num_samples):
        if distribution == 'rademacher':
            # `randint`'s upper bound is exclusive: draw from {0, 1}, then map to {-1, +1}.
            zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]
        else:
            zs = [torch.randn_like(p) for p in params]

        # Hessian-vector products H @ z. The graph is kept alive for every sample except the last one.
        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
        for h_z, z, p in zip(h_zs, zs, params):
            # z * (H @ z) estimates diag(H). Average it over the samples.
            state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)

debias(beta, step) staticmethod

Adam-style debias correction. Returns 1.0 - beta ** step.

Parameters:

Name Type Description Default
beta float

float. beta.

required
step int

int. number of step.

required
Source code in pytorch_optimizer/base/optimizer.py
177
178
179
180
181
182
183
184
@staticmethod
def debias(beta: float, step: int) -> float:
    r"""Adam-style debias correction. Returns `1.0 - beta ** step`.

    :param beta: float. beta.
    :param step: int. number of step.
    """
    return 1.0 - math.pow(beta, step)  # fmt: skip

debias_beta(beta, step) staticmethod

Apply the Adam-style debias correction into beta.

Simplified version of $\hat{\beta} = \beta \cdot \frac{1 - \beta^{t-1}}{1 - \beta^{t}}$, where $t$ is the step.

Parameters:

Name Type Description Default
beta float

float. beta.

required
step int

int. number of step.

required
Source code in pytorch_optimizer/base/optimizer.py
186
187
188
189
190
191
192
193
194
195
196
@staticmethod
def debias_beta(beta: float, step: int) -> float:
    r"""Apply the Adam-style debias correction into beta.

    Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`

    :param beta: float. beta.
    :param step: int. number of step.
    """
    beta_n: float = math.pow(beta, step)
    return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip

get_adanorm_gradient(grad, adanorm, exp_grad_norm=None, r=0.95) staticmethod

Get AdaNorm gradient.

Parameters:

Name Type Description Default
grad Tensor

torch.Tensor. gradient.

required
adanorm bool

bool. whether to use the AdaNorm variant.

required
exp_grad_norm Optional[Tensor]

Optional[torch.Tensor]. exp_grad_norm.

None
r Optional[float]

Optional[float]. EMA factor. between 0.9 ~ 0.99 is preferred.

0.95
Source code in pytorch_optimizer/base/optimizer.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
@staticmethod
def get_adanorm_gradient(
    grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
) -> torch.Tensor:
    r"""Get AdaNorm gradient.

    :param grad: torch.Tensor. gradient.
    :param adanorm: bool. whether to use the AdaNorm variant.
    :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
    :param r: Optional[float]. EMA factor. between 0.9 ~ 0.99 is preferred.
    """
    if not adanorm or exp_grad_norm is None:
        return grad

    if r is None:
        r = 0.95

    grad_norm = torch.linalg.norm(grad)

    exp_grad_norm.mul(r).add_(grad_norm, alpha=1.0 - r)

    return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad

get_rectify_step_size(is_rectify, step, lr, beta2, n_sma_threshold, degenerated_to_sgd) staticmethod

Get step size for rectify optimizer.

Parameters:

Name Type Description Default
is_rectify bool

bool. whether to apply rectify-variant.

required
step int

int. number of steps.

required
lr float

float. learning rate.

required
beta2 float

float. beta2.

required
n_sma_threshold int

int. SMA threshold.

required
degenerated_to_sgd bool

bool. degenerated to SGD.

required
Source code in pytorch_optimizer/base/optimizer.py
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
@staticmethod
def get_rectify_step_size(
    is_rectify: bool,
    step: int,
    lr: float,
    beta2: float,
    n_sma_threshold: int,
    degenerated_to_sgd: bool,
) -> Tuple[float, float]:
    r"""Get step size for rectify optimizer.

    :param is_rectify: bool. whether to apply rectify-variant.
    :param step: int. number of steps.
    :param lr: float. learning rate.
    :param beta2: float. beta2.
    :param n_sma_threshold: float. SMA threshold.
    :param degenerated_to_sgd: bool. degenerated to SGD.
    """
    step_size: float = lr
    n_sma: float = 0.0

    if is_rectify:
        n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
        beta2_t: float = beta2 ** step  # fmt: skip
        n_sma: float = n_sma_max - 2 * step * beta2_t / (1.0 - beta2_t)

        if n_sma >= n_sma_threshold:
            rt = math.sqrt(
                (1.0 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
            )
        elif degenerated_to_sgd:
            rt = 1.0
        else:
            rt = -1.0

        step_size *= rt

    return step_size, n_sma

get_rms(x) staticmethod

Get RMS.

Source code in pytorch_optimizer/base/optimizer.py
270
271
272
273
@staticmethod
def get_rms(x: torch.Tensor) -> float:
    r"""Get RMS."""
    return x.norm(2) / math.sqrt(x.numel())

get_stable_adamw_rms(grad, exp_avg_sq, eps=1e-16) staticmethod

Get StableAdamW RMS.

Parameters:

Name Type Description Default
grad Tensor

torch.Tensor. gradient.

required
exp_avg_sq Tensor

torch.Tensor. exp_avg_sq.

required
eps float

float. epsilon.

1e-16
Source code in pytorch_optimizer/base/optimizer.py
297
298
299
300
301
302
303
304
305
@staticmethod
def get_stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
    r"""Get StableAdamW RMS.

    :param grad: torch.Tensor. gradient.
    :param exp_avg_sq: torch.Tensor. exp_avg_sq.
    :param eps: float. epsilon.
    """
    return grad.pow(2).div_(exp_avg_sq.clip(min=eps)).mean().sqrt_().clip_(min=1.0).item()

init_group(group, **kwargs) abstractmethod

Initialize a parameter group of the optimizer (abstract; subclasses must implement it).

Source code in pytorch_optimizer/base/optimizer.py
380
381
382
383
@abstractmethod
def init_group(self, group: GROUP, **kwargs) -> None:  # pragma: no cover
    r"""Initialize a parameter group of the optimizer.

    Abstract hook that concrete optimizers must override. This base version does nothing and returns ``None``.

    :param group: GROUP. the parameter group to initialize.
    """
    return None

load_optimizer(optimizer, **kwargs) staticmethod

Build torch.optim.Optimizer class.

Source code in pytorch_optimizer/base/optimizer.py
28
29
30
31
32
33
34
35
36
37
38
@staticmethod
def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimizer:
    r"""Build torch.optim.Optimizer class.

    :param optimizer: OPTIMIZER_INSTANCE_OR_CLASS. an optimizer instance (returned unchanged) or an optimizer
        class (instantiated with ``kwargs``).
    :param kwargs: keyword arguments for the optimizer class. ``params`` is required when ``optimizer`` is a class.
    :raises ValueError: if ``optimizer`` is a class and ``params`` is missing from ``kwargs``.
    """
    if isinstance(optimizer, Optimizer):
        return optimizer

    if 'params' in kwargs:
        params = kwargs.pop('params')
        return optimizer(params, **kwargs)

    # Reaching here means `optimizer` is a class, not an instance, so it cannot be built without `params`.
    raise ValueError('need to pass `params` when you pass the `torch.optim.Optimizer` class.')

maximize_gradient(grad, maximize=False) staticmethod

Maximize the objective with respect to the params, instead of minimizing.

Source code in pytorch_optimizer/base/optimizer.py
397
398
399
400
401
@staticmethod
def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None:
    r"""Maximize the objective with respect to the params, instead of minimizing."""
    if maximize:
        grad.neg_()

set_hessian(param_groups, state, hessian) staticmethod

Set hessian to state from external source. Generally useful when using functorch as a base.

Example:
Here's an example::

    # Hutchinson's Estimator using HVP
    noise = tree_map(lambda v: torch.randn_like(v), params)
    loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
    hessian_diag_est  = tree_map(lambda a, b: a * b, hvp_est, noise)

    optimizer.set_hessian(hessian_diag_est)
    # OR
    optimizer.step(hessian=hessian_diag_est)

Parameters:

Name Type Description Default
param_groups PARAMETERS

PARAMETERS. parameter groups.

required
state STATE

STATE. optimizer state.

required
hessian List[Tensor]

List[torch.Tensor]. sequence of hessian to set.

required
Source code in pytorch_optimizer/base/optimizer.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@staticmethod
@torch.no_grad()
def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
    r"""Set hessian to state from external source. Generally useful when using functorch as a base.

    Example:
    -------
        Here's an example::

            # Hutchinson's Estimator using HVP
            noise = tree_map(lambda v: torch.randn_like(v), params)
            loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
            hessian_diag_est  = tree_map(lambda a, b: a * b, hvp_est, noise)

            optimizer.set_hessian(hessian_diag_est)
            # OR
            optimizer.step(hessian=hessian_diag_est)

    :param param_groups: PARAMETERS. parameter groups.
    :param state: STATE. optimizer state.
    :param hessian: List[torch.Tensor]. sequence of hessian to set, one per parameter, in the order the
        parameters appear across ``param_groups``.
    :raises ValueError: if the number of hessians differs from the number of parameters, or if a hessian's
        shape does not match its parameter.
    """
    params: List[torch.Tensor] = [p for group in param_groups for p in group['params']]

    # Check the count up front so nothing is written on a mismatch and extra hessians are not silently dropped.
    if len(hessian) != len(params):
        raise ValueError(
            f'[-] the number of parameters and hessians does not match. {len(params)} vs {len(hessian)}'
        )

    for p, h in zip(params, hessian):
        if p.size() != h.size():
            raise ValueError(f'[-] the shape of parameter and hessian does not match. {p.size()} vs {h.size()}')

        state[p]['hessian'] = h

view_as_real(param, *state_and_grads) staticmethod

View imaginary tensors as real tensors.

Source code in pytorch_optimizer/base/optimizer.py
385
386
387
388
389
390
391
392
393
394
395
@staticmethod
def view_as_real(param, *state_and_grads) -> tuple:
    r"""View imaginary tensors as real tensors."""
    if torch.is_complex(param):
        param = torch.view_as_real(param)
        state_and_grads = tuple(
            torch.view_as_real(s) if (s is not None and torch.is_complex(s)) else s if s is not None else None
            for s in state_and_grads
        )

    return param, *state_and_grads

zero_hessian(param_groups, state, pre_zero=True) staticmethod

Zero-out hessian.

Parameters:

Name Type Description Default
param_groups PARAMETERS

PARAMETERS. parameter groups.

required
state STATE

STATE. optimizer state.

required
pre_zero bool

bool. zero-out hessian before computing the hessian.

True
Source code in pytorch_optimizer/base/optimizer.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@staticmethod
def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True) -> None:
    r"""Ensure every eligible parameter has a ``hessian`` state entry, optionally reset to zero.

    Eligible parameters require grad and hold a dense (non-sparse) gradient.

    :param param_groups: PARAMETERS. parameter groups.
    :param state: STATE. optimizer state.
    :param pre_zero: bool. zero an already existing hessian in place.
    """
    for param in (q for group in param_groups for q in group['params']):
        eligible = param.requires_grad and param.grad is not None and not param.grad.is_sparse
        if not eligible:
            continue

        if 'hessian' not in state[param]:
            state[param]['hessian'] = torch.zeros_like(param)
        elif pre_zero:
            state[param]['hessian'].zero_()