Skip to content

Base

BaseOptimizer

Bases: ABC, Optimizer

Base optimizer class. Provides common functionality shared by the optimizers.

Source code in pytorch_optimizer/base/optimizer.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
class BaseOptimizer(ABC, Optimizer):
    """Base optimizer class. Provides common functionality shared by the optimizers.

    Subclasses must implement ``init_group`` and ``step``. Most helpers are static methods operating on the
    tensors/state passed to them, so they can be reused across optimizers.
    """

    def __init__(self, params: ParamsT, defaults: Defaults) -> None:
        super().__init__(params, defaults)

    @staticmethod
    def load_optimizer(optimizer: OptimizerInstanceOrClass, **kwargs) -> Optimizer:
        """Build (or pass through) a ``torch.optim.Optimizer``.

        Args:
            optimizer (OptimizerInstanceOrClass): An optimizer instance, returned unchanged, or an optimizer class
                to instantiate.
            **kwargs: Keyword arguments forwarded to the optimizer class. Must include ``params`` when
                ``optimizer`` is a class.

        Returns:
            Optimizer: The given instance, or a newly constructed optimizer.

        Raises:
            ValueError: If ``optimizer`` is a class and ``params`` is not in ``kwargs``.

        """
        if isinstance(optimizer, Optimizer):
            return optimizer

        if 'params' not in kwargs:
            raise ValueError('need to pass `params` when you pass the `torch.optim.Optimizer` class.')

        params = kwargs.pop('params')
        return optimizer(params, **kwargs)

    @staticmethod
    @torch.no_grad()
    def set_hessian(param_groups: ParamsT, state: State, hessian: List[torch.Tensor]) -> None:
        """Set hessian to state from external source. Generally useful when using functorch as a base.

        Hessian tensors are matched to parameters positionally, in ``param_groups`` iteration order, and stored
        under ``state[p]['hessian']``.

        Args:
            param_groups (ParamsT): Parameter groups from optimizer.
            state (State): Optimizer state dictionary.
            hessian (List[torch.Tensor]): Sequence of Hessian tensors to set, one per parameter.

        Raises:
            ValueError: If fewer Hessian tensors than parameters are given, or a shape does not match.

        Example:
            # Hutchinson's Estimator using Hessian-vector product (HVP)
            >>> noise = tree_map(lambda v: torch.randn_like(v), params)
            >>> loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
            >>> hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)

            >>> BaseOptimizer.set_hessian(optimizer.param_groups, optimizer.state, hessian_diag_est)
            # OR, for optimizers whose ``step`` accepts a ``hessian`` argument
            >>> optimizer.step(hessian=hessian_diag_est)

        """
        i: int = 0
        for group in param_groups or []:
            for p in group['params']:
                if i >= len(hessian):
                    raise ValueError(f'not enough hessian tensors. got {len(hessian)} for more parameters')

                if p.size() != hessian[i].size():
                    raise ValueError(
                        f'the shape of parameter and hessian does not match. {p.size()} vs {hessian[i].size()}'
                    )

                state[p]['hessian'] = hessian[i]
                i += 1

    @staticmethod
    def zero_hessian(param_groups: ParamsT, state: State, pre_zero: bool = True) -> None:
        """Zero-out Hessian.

        For every parameter with a dense gradient that requires grad, creates ``state[p]['hessian']`` as zeros
        if missing, otherwise zeroes the existing tensor in place when ``pre_zero`` is True.

        Args:
            param_groups (ParamsT): Parameter groups from the optimizer.
            state (State): Optimizer state dictionary.
            pre_zero (bool): If True, zero-out an existing Hessian before computing/updating it.

        """
        for group in param_groups or []:
            for p in group['params']:
                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
                    if 'hessian' not in state[p]:
                        state[p]['hessian'] = torch.zeros_like(p)
                    elif pre_zero:
                        state[p]['hessian'].zero_()

    @staticmethod
    @torch.no_grad()
    def compute_hutchinson_hessian(
        param_groups: ParamsT,
        state: State,
        num_samples: int = 1,
        alpha: float = 1.0,
        distribution: HutchinsonG = 'gaussian',
    ) -> None:
        r"""Hutchinson's approximate Hessian diagonal, accumulated into the state under key `hessian`.

        For each sample, draws noise ``z``, computes the Hessian-vector product ``Hz`` by differentiating the
        current gradients, and adds ``alpha / num_samples * (Hz * z)`` to ``state[p]['hessian']``.
        ``state[p]['hessian']`` must already exist (e.g. via ``zero_hessian``), and the gradients must carry a
        graph (e.g. from ``backward(create_graph=True)``) for ``torch.autograd.grad`` to differentiate them.

        Args:
            param_groups (ParamsT): Parameter groups from the optimizer.
            state (State): Optimizer state dictionary.
            num_samples (int): Number of times to sample noise vector `z` for the trace approximation.
            alpha (float): Scaling factor for the Hessian estimate.
            distribution (HutchinsonG): Noise distribution, either 'gaussian' or 'rademacher'.

        Raises:
            NotImplementedError: If ``distribution`` is not supported.

        """
        if distribution not in ('gaussian', 'rademacher'):
            raise NotImplementedError(f'hessian with distribution {distribution} is not implemented.')

        params: List[torch.Tensor] = [
            p
            for group in param_groups or []
            for p in group['params']
            if p.requires_grad and p.grad is not None and not p.grad.is_sparse
        ]
        if len(params) == 0:
            return

        grads = [p.grad for p in params]

        for i in range(num_samples):
            if distribution == 'rademacher':
                # `high` is exclusive: randint in {0, 1} mapped to {-1, +1}
                zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]
            else:
                zs = [torch.randn_like(p) for p in params]

            # keep the graph alive for all but the last sample
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)

    @staticmethod
    def apply_weight_decay(
        p: torch.Tensor,
        grad: Optional[torch.Tensor],
        lr: float,
        weight_decay: float,
        weight_decouple: bool,
        fixed_decay: bool,
        ratio: Optional[float] = None,
    ) -> None:
        """Apply weight decay in an in-place manner.

        Decoupled: ``p *= 1 - weight_decay * (1 if fixed_decay else lr) * (ratio or 1)``.
        Coupled (L2): ``grad += weight_decay * p`` when ``weight_decay > 0`` and ``grad`` is given.

        Args:
            p (torch.Tensor): Parameter tensor to apply weight decay to.
            grad (Optional[torch.Tensor]): Gradient tensor of parameter p.
            lr (float): Learning rate to scale the update.
            weight_decay (float): Weight decay coefficient.
            weight_decouple (bool): If True, applies decoupled weight decay as in AdamW.
            fixed_decay (bool): If True, fixes weight decay to not depend on learning rate.
            ratio (Optional[float]): Optional scaling factor for weight decay.

        """
        if weight_decouple:
            p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
        elif weight_decay > 0.0 and grad is not None:
            grad.add_(p, alpha=weight_decay)

    @staticmethod
    def apply_cautious_weight_decay(
        p: torch.Tensor,
        update: torch.Tensor,
        lr: float,
        weight_decay: float,
    ) -> None:
        """Apply cautious weight decay (CWD) in an in-place manner.

        Decays only the entries where ``update`` and ``p`` share a sign (``update * p >= 0``).

        Args:
            p (torch.Tensor): Parameter tensor to apply weight decay to.
            update (torch.Tensor): Update tensor whose sign gates the decay.
            lr (float): Learning rate to scale the update.
            weight_decay (float): Weight decay coefficient.

        """
        p.copy_(torch.where(update * p >= 0, p * (1.0 - weight_decay * lr), p))

    @staticmethod
    def apply_ams_bound(
        ams_bound: bool,
        exp_avg_sq: torch.Tensor,
        max_exp_avg_sq: Optional[torch.Tensor],
        eps: float,
        exp_avg_sq_eps: float = 1e-15,
    ) -> torch.Tensor:
        """Apply AMSBound variant and return the update denominator.

        When ``ams_bound`` is set, ``max_exp_avg_sq`` is updated in place with the running element-wise maximum
        and used instead of ``exp_avg_sq``.

        Args:
            ams_bound (bool): Whether to apply the AMSBound variant.
            exp_avg_sq (torch.Tensor): Exponential moving average of squared gradients.
            max_exp_avg_sq (Optional[torch.Tensor]): Running maximum of exp_avg_sq; required if ``ams_bound``.
            eps (float): Small epsilon added after the square root.
            exp_avg_sq_eps (float): Epsilon added to the second-moment estimate before the square root.

        Returns:
            torch.Tensor: ``sqrt(v + exp_avg_sq_eps) + eps`` as a new tensor.

        """
        if ams_bound:
            if torch.is_complex(max_exp_avg_sq):
                max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)

            torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            de_nom = max_exp_avg_sq.add(exp_avg_sq_eps)
        else:
            de_nom = exp_avg_sq.add(exp_avg_sq_eps)

        return de_nom.sqrt_().add_(eps)

    @staticmethod
    def debias(beta: float, step: int) -> float:
        """Adam-style debias correction, ``1 - beta ** step``.

        Args:
            beta (float): Exponential decay rate for moment estimates.
            step (int): Current optimization step number.

        """
        return 1.0 - math.pow(beta, step)  # fmt: skip

    @staticmethod
    def debias_beta(beta: float, step: int) -> float:
        r"""Apply the Adam-style debias correction into beta.

        Simplified version of `\^{beta} = beta * (1.0 - beta ** (step - 1)) / (1.0 - beta ** step)`

        Args:
            beta (float): The original beta decay rate.
            step (int): Current optimization step number.

        """
        beta_n: float = math.pow(beta, step)
        return (beta_n - beta) / (beta_n - 1.0)  # fmt: skip

    @staticmethod
    def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
        """Apply AdamD variant.

        Args:
            adam_debias (bool): If True, skip first-moment bias correction (only the denominator is corrected),
                avoiding inflated step sizes early in training.
            step_size (float): The step size for the update.
            bias_correction1 (float): The bias correction factor for the first moment.

        Returns:
            float: ``step_size`` if ``adam_debias`` else ``step_size / bias_correction1``.

        """
        return step_size if adam_debias else step_size / bias_correction1

    @staticmethod
    def get_rectify_step_size(
        is_rectify: bool,
        step: int,
        lr: float,
        beta2: float,
        n_sma_threshold: int,
        degenerated_to_sgd: bool,
    ) -> Tuple[float, float]:
        """Get step size for rectify optimizer (RAdam-style).

        Args:
            is_rectify (bool): Whether to apply the rectify variant.
            step (int): Current step number.
            lr (float): Base learning rate.
            beta2 (float): Beta2 parameter from optimizer.
            n_sma_threshold (int): Simple Moving Average (SMA) length threshold for rectification.
            degenerated_to_sgd (bool): Whether to use the plain ``lr`` if below threshold.

        Returns:
            Tuple[float, float]: ``(step_size, n_sma)``. Below the threshold without ``degenerated_to_sgd`` the
            step size is ``-lr``. ``n_sma`` is 0.0 when ``is_rectify`` is False.

        """
        step_size: float = lr
        n_sma: float = 0.0

        if is_rectify:
            n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
            beta2_t: float = beta2 ** step  # fmt: skip
            n_sma = n_sma_max - 2 * step * beta2_t / (1.0 - beta2_t)

            if n_sma >= n_sma_threshold:
                rt = math.sqrt(
                    (1.0 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
                )
            elif degenerated_to_sgd:
                rt = 1.0
            else:
                rt = -1.0

            step_size *= rt

        return step_size, n_sma

    @staticmethod
    def get_adanorm_gradient(
        grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
    ) -> torch.Tensor:
        r"""Get AdaNorm gradient.

        Updates ``exp_grad_norm`` in place with an EMA of the gradient norm. When the EMA exceeds the current
        norm, returns a new gradient rescaled to the EMA norm; otherwise returns ``grad`` itself.

        Args:
            grad (torch.Tensor): Gradient.
            adanorm (bool): Whether to use the AdaNorm variant.
            exp_grad_norm (Optional[torch.Tensor]): Exponential moving average of gradient norm (updated in place).
            r (Optional[float]): EMA factor; between 0.9 and 0.99 is preferred. ``None`` means 0.95.

        """
        if not adanorm or exp_grad_norm is None:
            return grad

        if r is None:
            r = 0.95

        grad_norm = torch.linalg.norm(grad)

        # must be in place: an out-of-place `mul` would discard the EMA update entirely
        exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

        return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad

    @staticmethod
    def get_rms(x: Union[List[torch.Tensor], torch.Tensor]) -> Union[List[torch.Tensor], torch.Tensor]:
        """Get RMS (``||x||_2 / sqrt(numel)``) of a tensor, or of each tensor in a list."""
        if isinstance(x, torch.Tensor):
            return x.norm(2).div_(math.sqrt(x.numel()))

        factors: List[float] = [math.sqrt(p.numel()) for p in x]
        norms = torch._foreach_norm(x, ord=2)
        torch._foreach_div_(norms, factors)

        return norms  # pyright: ignore[reportReturnType]

    @staticmethod
    def approximate_sq_grad(
        exp_avg_sq_row: Union[List[torch.Tensor], torch.Tensor],
        exp_avg_sq_col: Union[List[torch.Tensor], torch.Tensor],
        output: Union[List[torch.Tensor], torch.Tensor],
    ) -> None:
        """Get approximation of EMA of squared gradient (factored, Adafactor-style), written into ``output``.

        Computes ``rsqrt(row / mean(row)) ⊗ rsqrt(col)``. The row/col statistics are not modified.
        """
        if isinstance(exp_avg_sq_row, torch.Tensor):
            r_factor: torch.Tensor = (
                (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
            )
            c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
            torch.mul(r_factor, c_factor, out=output)
            return

        row_means = [r.mean(dim=-1, keepdim=True) for r in exp_avg_sq_row]

        # `_foreach_div` returns fresh tensors, so the in-place rsqrt cannot touch the row state
        r_factors = torch._foreach_div(exp_avg_sq_row, row_means)
        foreach_rsqrt_(r_factors)
        r_factors = [r_factor.unsqueeze(-1) for r_factor in r_factors]

        # out-of-place: `unsqueeze` returns a view of the column state, so an in-place rsqrt would corrupt it
        c_factors = [c_factor.unsqueeze(-2).rsqrt() for c_factor in exp_avg_sq_col]

        torch._foreach_copy_(output, torch._foreach_mul(r_factors, c_factors))

    @staticmethod
    def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
        """Apply the Cautious Optimizer feature.

        Zeroes update entries whose sign disagrees with the gradient and rescales the rest by
        ``numel / (num_kept + 1)``.

        Args:
            update (torch.Tensor): Update tensor, masked in-place.
            grad (torch.Tensor): Gradient tensor.

        """
        mask = (update * grad > 0).to(grad.dtype)
        mask.mul_(mask.numel() / (mask.sum() + 1))
        update.mul_(mask)

    @staticmethod
    def can_use_foreach(group: ParamGroup, foreach: Optional[bool]) -> bool:
        """Check if foreach operations can be used for this parameter group.

        Args:
            group (ParamGroup): Parameter group dictionary.
            foreach (Optional[bool]): User-specified foreach preference (None for auto-detect).

        Returns:
            True if foreach is not disabled, at least one parameter has a gradient, and no such parameter has a
            sparse gradient or complex dtype; False otherwise.

        """
        if foreach is False:
            return False

        has_param: bool = False
        for p in group['params']:
            g = p.grad
            if g is None:
                continue

            has_param = True
            if g.is_sparse or torch.is_complex(p):
                return False

        return has_param

    @staticmethod
    def collect_trainable_params(
        group: ParamGroup,
        state: State,
        state_keys: Optional[List[str]] = None,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], Dict[str, List[torch.Tensor]]]:
        """Collect trainable parameters, gradients, and state tensors from a group.

        Args:
            group: Parameter group dictionary.
            state: Optimizer state dictionary.
            state_keys: List of state keys to collect (e.g., ['exp_avg', 'exp_avg_sq']).

        Returns:
            Tuple containing:
            - params: List of parameter tensors with gradients
            - grads: List of corresponding gradient tensors
            - state_dict: Dictionary mapping state keys to lists of state tensors

        """
        if state_keys is None:
            state_keys = []

        params: List[torch.Tensor] = []
        grads: List[torch.Tensor] = []
        state_dict: Dict[str, List[torch.Tensor]] = {key: [] for key in state_keys}

        for p in group['params']:
            if p.grad is None:
                continue

            params.append(p)
            grads.append(p.grad)

            if state_keys:
                p_state = state[p]
                for key in state_keys:
                    # NOTE(review): a missing key is skipped, so that list becomes shorter than `params`;
                    # callers presumably call `init_group` first so every key exists — confirm.
                    if key in p_state:
                        state_dict[key].append(p_state[key])

        return params, grads, state_dict

    @staticmethod
    def apply_weight_decay_foreach(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        lr: Union[List[float], List[torch.Tensor], float, torch.Tensor],
        weight_decay: float,
        weight_decouple: bool,
        fixed_decay: bool,
    ) -> None:
        """Apply weight decay to a list of parameters.

        Args:
            params: List of parameter tensors.
            grads: List of gradient tensors.
            lr: Learning rate: a scalar (float or tensor) or a per-parameter list of floats or tensors.
            weight_decay: Weight decay coefficient.
            weight_decouple: If True, applies decoupled weight decay as in AdamW.
            fixed_decay: If True, fixes weight decay to not depend on learning rate.

        """
        if weight_decay == 0.0:
            return

        if not weight_decouple:
            torch._foreach_add_(grads, params, alpha=weight_decay)
            return

        if fixed_decay:
            factor = 1.0 - weight_decay
        elif isinstance(lr, Sequence):
            if len(lr) > 0 and isinstance(lr[0], torch.Tensor):
                factor = torch._foreach_mul(lr, -weight_decay)
                torch._foreach_add_(factor, 1.0)
            else:
                # plain floats: `torch._foreach_mul` only accepts a tensor list as its first argument
                factor = [1.0 - weight_decay * x for x in lr]
        else:
            factor = 1.0 - weight_decay * lr

        torch._foreach_mul_(params, factor)

    @staticmethod
    def get_stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
        """Get StableAdamW RMS, ``max(1, sqrt(mean(grad^2 / max(exp_avg_sq, eps))))``.

        Args:
            grad (torch.Tensor): gradient.
            exp_avg_sq (torch.Tensor): Exponential moving average of squared gradient.
            eps (float): Small value to prevent division by zero.

        """
        return grad.pow(2).div_(exp_avg_sq.clip(min=eps)).mean().sqrt_().clip_(min=1.0).item()

    @staticmethod
    def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
        """Raise ValueError unless ``x`` lies in the interval described by ``range_type``.

        Args:
            x (float): Value to check.
            name (str): Name used in the error message.
            low (float): Lower bound.
            high (float): Upper bound.
            range_type (str): One of '[)', '[]', '(]', '()' (bracket = inclusive, parenthesis = exclusive).

        Raises:
            ValueError: If ``x`` is out of range or ``range_type`` is not supported.

        """
        if range_type == '[)':
            valid = low <= x < high
        elif range_type == '[]':
            valid = low <= x <= high
        elif range_type == '(]':
            valid = low < x <= high
        elif range_type == '()':
            valid = low < x < high
        else:
            raise ValueError(f'unsupported range_type {range_type}. expected one of [), [], (], ()')

        if not valid:
            raise ValueError(f'{name} must be in the range {range_type[0]}{low}, {high}{range_type[1]}')

    @staticmethod
    def validate_non_negative(x: Optional[float], name: str) -> None:
        """Raise ValueError if ``x`` is given and negative."""
        if x is not None and x < 0.0:
            raise ValueError(f'{name} must be non-negative')

    @staticmethod
    def validate_non_positive(x: Optional[float], name: str) -> None:
        """Raise ValueError if ``x`` is given and positive."""
        if x is not None and x > 0.0:
            raise ValueError(f'{name} must be non-positive')

    @staticmethod
    def validate_positive(x: Union[float, int], name: str) -> None:
        """Raise ValueError if ``x`` is not strictly positive."""
        if x <= 0:
            raise ValueError(f'{name} must be positive')

    @staticmethod
    def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper') -> None:
        """Raise ValueError if ``constant`` exceeds ``boundary`` ('upper') or falls below it ('lower')."""
        if bound_type == 'upper' and constant > boundary:
            raise ValueError(f'constant {constant} must be in a range of (-inf, {boundary}]')
        if bound_type == 'lower' and constant < boundary:
            raise ValueError(f'constant {constant} must be in a range of [{boundary}, inf)')

    @staticmethod
    def validate_step(step: int, step_type: str) -> None:
        """Raise NegativeStepError if ``step`` is less than 1."""
        if step < 1:
            raise NegativeStepError(step, step_type=step_type)

    @staticmethod
    def validate_options(x: str, name: str, options: List[str]) -> None:
        """Raise ValueError if ``x`` is not one of ``options``."""
        if x not in options:
            opts: str = ' or '.join([f"'{option}'" for option in options]).strip()
            raise ValueError(f'{name} {x} must be one of ({opts})')

    @staticmethod
    def validate_learning_rate(learning_rate: Optional[float]) -> None:
        """Raise NegativeLRError if the learning rate is given and negative."""
        if learning_rate is not None and learning_rate < 0.0:
            raise NegativeLRError(learning_rate)

    @staticmethod
    def validate_mod(x: int, y: int) -> None:
        """Raise ValueError unless ``x`` is divisible by ``y``."""
        if x % y != 0:
            raise ValueError(f'{x} must be divisible by {y}')

    def validate_betas(
        self,
        betas: Union[Betas, Tuple[None, float]],
        beta_range_type: str = '[)',
        beta3_range_type: str = '[]',
    ) -> None:
        """Validate beta1 (if not None), beta2, and optional beta3 (if present and not None) to lie in [0, 1]-type ranges."""
        if betas[0] is not None:
            self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type=beta_range_type)

        self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type=beta_range_type)

        if len(betas) < 3:
            return

        if betas[2] is not None:
            self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type=beta3_range_type)

    def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None:
        """Validate that ``nus`` (a scalar or a pair) lies within [0, 1]."""
        if isinstance(nus, tuple):
            nu1, nu2 = nus
            self.validate_range(nu1, 'nu1', 0.0, 1.0, range_type='[]')
            self.validate_range(nu2, 'nu2', 0.0, 1.0, range_type='[]')
        else:
            self.validate_range(nus, 'nu', 0.0, 1.0, range_type='[]')

    @abstractmethod
    def init_group(self, group: ParamGroup, **kwargs) -> None:  # pragma: no cover
        """Initialize the state for a parameter group of the optimizer."""
        return

    @staticmethod
    def view_as_real(param, *state_and_grads) -> tuple:
        """View complex tensors as real tensors.

        Only acts when ``param`` is complex; then ``param`` and every complex entry of ``state_and_grads`` are
        viewed as real, while non-complex entries and ``None`` pass through. Returns ``(param, *state_and_grads)``.
        """
        if torch.is_complex(param):
            param = torch.view_as_real(param)
            state_and_grads = tuple(
                torch.view_as_real(s) if s is not None and torch.is_complex(s) else s for s in state_and_grads
            )

        return param, *state_and_grads

    @staticmethod
    def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None:
        """Maximize the objective with respect to the params, instead of minimizing (negates ``grad`` in place)."""
        if maximize:
            grad.neg_()

    def step(self, closure: Closure = None) -> Loss:  # pragma: no cover
        raise NotImplementedError

apply_adam_debias(adam_debias, step_size, bias_correction1) staticmethod

Apply AdamD variant.

Parameters:

Name Type Description Default
adam_debias bool

If True, only corrects the denominator to avoid inflating step sizes early in training.

required
step_size float

The step size for the update.

required
bias_correction1 float

The bias correction factor for the first moment.

required
Source code in pytorch_optimizer/base/optimizer.py
233
234
235
236
237
238
239
240
241
242
243
@staticmethod
def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
    """Return the step size, bias-corrected for the first moment unless AdamD is enabled.

    Args:
        adam_debias (bool): If True, skip the first-moment correction (AdamD corrects only the denominator),
            avoiding inflated step sizes early in training.
        step_size (float): The step size for the update.
        bias_correction1 (float): The bias correction factor for the first moment.

    Returns:
        float: ``step_size`` when ``adam_debias`` is set, otherwise ``step_size / bias_correction1``.

    """
    if adam_debias:
        return step_size
    return step_size / bias_correction1

apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps, exp_avg_sq_eps=1e-15) staticmethod

Apply AMSBound variant.

Parameters:

Name Type Description Default
ams_bound bool

Whether to apply the AMSBound variant.

required
exp_avg_sq Tensor

Exponential moving average of squared gradients.

required
max_exp_avg_sq Optional[Tensor]

Maximum of all exp_avg_sq elements, for AMSBound.

required
eps float

Small epsilon value for numerical stability.

required
exp_avg_sq_eps float

Epsilon added to exp_avg_sq (or to max_exp_avg_sq when ams_bound is set) before taking the square root.

1e-15
Source code in pytorch_optimizer/base/optimizer.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
@staticmethod
def apply_ams_bound(
    ams_bound: bool,
    exp_avg_sq: torch.Tensor,
    max_exp_avg_sq: Optional[torch.Tensor],
    eps: float,
    exp_avg_sq_eps: float = 1e-15,
) -> torch.Tensor:
    """Compute the update denominator, optionally using the AMSBound running maximum.

    Args:
        ams_bound (bool): Whether to apply the AMSBound variant.
        exp_avg_sq (torch.Tensor): Exponential moving average of squared gradients.
        max_exp_avg_sq (Optional[torch.Tensor]): Running maximum of exp_avg_sq, updated in place when
            ``ams_bound`` is set.
        eps (float): Small epsilon added after the square root.
        exp_avg_sq_eps (float): Epsilon added to the second-moment estimate before the square root.

    Returns:
        torch.Tensor: A new tensor ``sqrt(v + exp_avg_sq_eps) + eps``.

    """
    if not ams_bound:
        return exp_avg_sq.add(exp_avg_sq_eps).sqrt_().add_(eps)

    running_max = torch.view_as_real(max_exp_avg_sq) if torch.is_complex(max_exp_avg_sq) else max_exp_avg_sq
    torch.maximum(running_max, exp_avg_sq, out=running_max)
    return running_max.add(exp_avg_sq_eps).sqrt_().add_(eps)

apply_cautious(update, grad) staticmethod

Apply the Cautious Optimizer feature.

Parameters:

Name Type Description Default
update Tensor

Update tensor, masked in-place.

required
grad Tensor

Gradient tensor.

required
Source code in pytorch_optimizer/base/optimizer.py
349
350
351
352
353
354
355
356
357
358
359
360
@staticmethod
def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Mask ``update`` in place, keeping only entries whose sign agrees with ``grad``.

    Surviving entries are rescaled by ``numel / (num_kept + 1)``.

    Args:
        update (torch.Tensor): Update tensor, masked in-place.
        grad (torch.Tensor): Gradient tensor.

    """
    agree = torch.gt(update * grad, 0).to(grad.dtype)
    scale = agree.numel() / (agree.sum() + 1)
    agree.mul_(scale)
    update.mul_(agree)

apply_cautious_weight_decay(p, update, lr, weight_decay) staticmethod

Apply cautious weight decay (CWD) in an in-place manner.

Parameters:

Name Type Description Default
p Tensor

Parameter tensor to apply weight decay to.

required
update Tensor

Update tensor; decay is applied only where update and p share a sign (update * p >= 0).

required
lr float

Learning rate to scale the update.

required
weight_decay float

Weight decay coefficient (L2 penalty).

required
Source code in pytorch_optimizer/base/optimizer.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
@staticmethod
def apply_cautious_weight_decay(
    p: torch.Tensor,
    update: torch.Tensor,
    lr: float,
    weight_decay: float,
) -> None:
    """Decay ``p`` in place only where ``update`` and ``p`` share a sign (cautious weight decay, CWD).

    Args:
        p (torch.Tensor): Parameter tensor to apply weight decay to.
        update (torch.Tensor): Update tensor whose sign gates the decay (``update * p >= 0``).
        lr (float): Learning rate to scale the decay.
        weight_decay (float): Weight decay coefficient.

    """
    decayed = p * (1.0 - weight_decay * lr)
    same_sign = update * p >= 0
    p.copy_(torch.where(same_sign, decayed, p))

apply_weight_decay(p, grad, lr, weight_decay, weight_decouple, fixed_decay, ratio=None) staticmethod

Apply weight decay in an in-place manner.

Parameters:

Name Type Description Default
p Tensor

Parameter tensor to apply weight decay to.

required
grad Tensor

Gradient tensor of parameter p.

required
lr float

Learning rate to scale the update.

required
weight_decay float

Weight decay coefficient (L2 penalty).

required
weight_decouple bool

If True, applies decoupled weight decay as in AdamW.

required
fixed_decay bool

If True, fixes weight decay to not depend on learning rate.

required
ratio Optional[float]

Optional scaling factor for weight decay.

None
Source code in pytorch_optimizer/base/optimizer.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
@staticmethod
def apply_weight_decay(
    p: torch.Tensor,
    grad: Optional[torch.Tensor],
    lr: float,
    weight_decay: float,
    weight_decouple: bool,
    fixed_decay: bool,
    ratio: Optional[float] = None,
) -> None:
    """Apply weight decay in place, either to the parameter (decoupled) or to the gradient (L2).

    Args:
        p (torch.Tensor): Parameter tensor to apply weight decay to.
        grad (Optional[torch.Tensor]): Gradient tensor of parameter p.
        lr (float): Learning rate to scale the decay (ignored when ``fixed_decay``).
        weight_decay (float): Weight decay coefficient.
        weight_decouple (bool): If True, shrink ``p`` directly as in AdamW.
        fixed_decay (bool): If True, the decoupled decay does not depend on the learning rate.
        ratio (Optional[float]): Optional extra scaling factor for the decoupled decay.

    """
    if not weight_decouple:
        if weight_decay > 0.0 and grad is not None:
            grad.add_(p, alpha=weight_decay)
        return

    lr_term = 1.0 if fixed_decay else lr
    ratio_term = 1.0 if ratio is None else ratio
    p.mul_(1.0 - weight_decay * lr_term * ratio_term)

apply_weight_decay_foreach(params, grads, lr, weight_decay, weight_decouple, fixed_decay) staticmethod

Apply weight decay to a list of parameters.

Parameters:

Name Type Description Default
params List[Tensor]

List of parameter tensors.

required
grads List[Tensor]

List of gradient tensors.

required
lr Union[List[float], List[Tensor], float, Tensor]

Learning rate: a scalar (float or tensor) or a per-parameter list of floats or tensors.

required
weight_decay float

Weight decay coefficient.

required
weight_decouple bool

If True, applies decoupled weight decay as in AdamW.

required
fixed_decay bool

If True, fixes weight decay to not depend on learning rate.

required
Source code in pytorch_optimizer/base/optimizer.py
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
@staticmethod
def apply_weight_decay_foreach(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    lr: Union[List[float], List[torch.Tensor], float, torch.Tensor],
    weight_decay: float,
    weight_decouple: bool,
    fixed_decay: bool,
) -> None:
    """Apply weight decay to a list of parameters.

    Args:
        params: List of parameter tensors.
        grads: List of gradient tensors.
        lr: Learning rate: a scalar (float or tensor) or a per-parameter list of floats or tensors.
        weight_decay: Weight decay coefficient.
        weight_decouple: If True, applies decoupled weight decay as in AdamW (scales ``params`` in place).
        fixed_decay: If True, fixes weight decay to not depend on learning rate.

    Otherwise (coupled / L2), ``grads`` are updated in place with ``weight_decay * params``.
    Does nothing when ``weight_decay`` is 0.

    """
    if weight_decay == 0.0:
        return

    if not weight_decouple:
        torch._foreach_add_(grads, params, alpha=weight_decay)
        return

    if fixed_decay:
        factor = 1.0 - weight_decay
    elif isinstance(lr, Sequence):
        if len(lr) > 0 and isinstance(lr[0], torch.Tensor):
            factor = torch._foreach_mul(lr, -weight_decay)
            torch._foreach_add_(factor, 1.0)
        else:
            # plain floats: `torch._foreach_mul` only accepts a tensor list as its first argument
            factor = [1.0 - weight_decay * x for x in lr]
    else:
        factor = 1.0 - weight_decay * lr

    torch._foreach_mul_(params, factor)

approximate_sq_grad(exp_avg_sq_row, exp_avg_sq_col, output) staticmethod

Get approximation of EMA of squared gradient.

Source code in pytorch_optimizer/base/optimizer.py
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
@staticmethod
def approximate_sq_grad(
    exp_avg_sq_row: Union[List[torch.Tensor], torch.Tensor],
    exp_avg_sq_col: Union[List[torch.Tensor], torch.Tensor],
    output: Union[List[torch.Tensor], torch.Tensor],
) -> None:
    """Get approximation of EMA of squared gradient (factored, Adafactor-style), written into ``output``.

    Computes ``rsqrt(row / mean(row)) ⊗ rsqrt(col)`` for a single tensor or element-wise over lists.
    The row/col statistics are left unmodified.
    """
    if isinstance(exp_avg_sq_row, torch.Tensor):
        r_factor: torch.Tensor = (
            (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        )
        c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        torch.mul(r_factor, c_factor, out=output)
        return

    row_means = [r.mean(dim=-1, keepdim=True) for r in exp_avg_sq_row]

    # `_foreach_div` returns fresh tensors, so the in-place rsqrt cannot touch the row state
    r_factors = torch._foreach_div(exp_avg_sq_row, row_means)
    foreach_rsqrt_(r_factors)
    r_factors = [r_factor.unsqueeze(-1) for r_factor in r_factors]

    # out-of-place: `unsqueeze` returns a view of the column state, so an in-place rsqrt would corrupt it
    c_factors = [c_factor.unsqueeze(-2).rsqrt() for c_factor in exp_avg_sq_col]

    torch._foreach_copy_(output, torch._foreach_mul(r_factors, c_factors))

can_use_foreach(group, foreach) staticmethod

Check if foreach operations can be used for this parameter group.

Parameters:

Name Type Description Default
group ParamGroup

Parameter group dictionary.

required
foreach Optional[bool]

User-specified foreach preference (None for auto-detect).

required

Returns:

Type Description
bool

True if foreach operations should be used, False otherwise.

Source code in pytorch_optimizer/base/optimizer.py
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
@staticmethod
def can_use_foreach(group: ParamGroup, foreach: Optional[bool]) -> bool:
    """Decide whether the foreach (multi-tensor) code path applies to a parameter group.

    Args:
        group (ParamGroup): Parameter group dictionary; only its ``'params'`` entry is read.
        foreach (Optional[bool]): User-specified foreach preference (None for auto-detect). An explicit
            ``False`` always disables the foreach path.

    Returns:
        True when the user has not opted out, at least one parameter has a gradient, and no parameter
        that has a gradient is complex or carries a sparse gradient. False otherwise.

    """
    if foreach is False:
        return False

    # Parameters without gradients are irrelevant for this step and are ignored entirely.
    with_grads = [param for param in group['params'] if param.grad is not None]

    # Sparse gradients and complex parameters are not handled by the foreach path.
    if any(param.grad.is_sparse or torch.is_complex(param) for param in with_grads):
        return False

    return len(with_grads) > 0

collect_trainable_params(group, state, state_keys=None) staticmethod

Collect trainable parameters, gradients, and state tensors from a group.

Parameters:

Name Type Description Default
group ParamGroup

Parameter group dictionary.

required
state State

Optimizer state dictionary.

required
state_keys Optional[List[str]]

List of state keys to collect (e.g., ['exp_avg', 'exp_avg_sq']).

None

Returns:

Type Description
Tuple[List[Tensor], List[Tensor], Dict[str, List[Tensor]]]

Tuple containing:

  • params: List of parameter tensors with gradients
  • grads: List of corresponding gradient tensors
  • state_dict: Dictionary mapping state keys to lists of state tensors
Source code in pytorch_optimizer/base/optimizer.py
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
@staticmethod
def collect_trainable_params(
    group: ParamGroup,
    state: State,
    state_keys: Optional[List[str]] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], Dict[str, List[torch.Tensor]]]:
    """Gather the parameters that have gradients, their gradients, and selected per-parameter state.

    Args:
        group: Parameter group dictionary; its ``'params'`` entry is iterated in order.
        state: Optimizer state dictionary, indexed by parameter. Only read when ``state_keys`` is non-empty.
        state_keys: State entries to collect (e.g., ['exp_avg', 'exp_avg_sq']). None means none.

    Returns:
        A ``(params, grads, state_dict)`` tuple:
        - params: parameters whose ``grad`` is not None, in group order
        - grads: the matching gradient tensors
        - state_dict: for each requested key, the state tensors of those parameters that have that key
          (a parameter missing the key is simply skipped for it)

    """
    keys: List[str] = [] if state_keys is None else state_keys

    params: List[torch.Tensor] = [p for p in group['params'] if p.grad is not None]
    grads: List[torch.Tensor] = [p.grad for p in params]
    state_dict: Dict[str, List[torch.Tensor]] = {key: [] for key in keys}

    if not keys:
        return params, grads, state_dict

    for param in params:
        param_state = state[param]
        for key in keys:
            if key in param_state:
                state_dict[key].append(param_state[key])

    return params, grads, state_dict

compute_hutchinson_hessian(param_groups, state, num_samples=1, alpha=1.0, distribution='gaussian') staticmethod

Hutchinson's approximate Hessian, added to the state under key hessian.

Parameters:

Name Type Description Default
param_groups ParamsT

Parameter groups from the optimizer.

required
state State

Optimizer state dictionary.

required
num_samples int

Number of times to sample noise vector z for the trace approximation.

1
alpha float

Scaling factor for the Hessian estimate.

1.0
distribution HutchinsonG

Type of noise distribution used (e.g., Rademacher).

'gaussian'
Source code in pytorch_optimizer/base/optimizer.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
@staticmethod
@torch.no_grad()
def compute_hutchinson_hessian(
    param_groups: ParamsT,
    state: State,
    num_samples: int = 1,
    alpha: float = 1.0,
    distribution: HutchinsonG = 'gaussian',
) -> None:
    r"""Hutchinson's approximate Hessian, added to the state under key `hessian`.

    For each of ``num_samples`` noise vectors ``z``, computes the Hessian-vector product ``H z`` by
    differentiating the stored gradients w.r.t. the parameters. It then accumulates
    ``alpha / num_samples * (H z) * z`` into ``state[p]['hessian']``. The gradients therefore need an
    autograd graph (e.g. produced by ``backward(create_graph=True)``), and ``state[p]['hessian']`` must
    already exist for every selected parameter.

    Args:
        param_groups (ParamsT): Parameter groups from the optimizer.
        state (State): Optimizer state dictionary.
        num_samples (int): Number of times to sample noise vector `z` for the trace approximation.
        alpha (float): Scaling factor for the Hessian estimate.
        distribution (HutchinsonG): Noise distribution, either ``'gaussian'`` or ``'rademacher'``.

    Raises:
        NotImplementedError: If ``distribution`` is not one of the supported values.

    """
    if distribution not in ('gaussian', 'rademacher'):
        raise NotImplementedError(f'hessian with distribution {distribution} is not implemented.')

    params: List[torch.Tensor] = [
        p
        for group in param_groups or []
        for p in group['params']
        if p.requires_grad and p.grad is not None and not p.grad.is_sparse
    ]
    if len(params) == 0:
        return

    grads = [p.grad for p in params]

    for i in range(num_samples):
        if distribution == 'rademacher':
            # ``high`` is exclusive: draw from {0, 1}, then map to {-1, +1}.
            zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]
        else:
            zs = [torch.randn_like(p) for p in params]

        # Keep the graph alive for every sample except the last one.
        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
        for h_z, z, p in zip(h_zs, zs, params):
            state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)

debias(beta, step) staticmethod

Adam-style debias correction.

Parameters:

Name Type Description Default
beta float

Exponential decay rate for moment estimates.

required
step int

Current optimization step number.

required
Source code in pytorch_optimizer/base/optimizer.py
208
209
210
211
212
213
214
215
216
217
@staticmethod
def debias(beta: float, step: int) -> float:
    """Compute the Adam-style bias-correction term ``1 - beta ** step``.

    Args:
        beta (float): Exponential decay rate for moment estimates.
        step (int): Current optimization step number.

    Returns:
        float: ``1 - beta ** step``, the denominator used for Adam-style bias correction.

    """
    decay_power: float = math.pow(beta, step)
    return 1.0 - decay_power

debias_beta(beta, step) staticmethod

Apply the Adam-style debias correction into beta.

Simplified version of $\hat{\beta} = \beta \cdot \frac{1 - \beta^{\text{step} - 1}}{1 - \beta^{\text{step}}}$, computed as $\frac{\beta^{\text{step}} - \beta}{\beta^{\text{step}} - 1}$.

Parameters:

Name Type Description Default
beta float

The original beta decay rate.

required
step int

Current optimization step number.

required
Source code in pytorch_optimizer/base/optimizer.py
219
220
221
222
223
224
225
226
227
228
229
230
231
@staticmethod
def debias_beta(beta: float, step: int) -> float:
    r"""Fold the Adam-style bias correction into the decay rate itself.

    Evaluates ``beta * (1 - beta ** (step - 1)) / (1 - beta ** step)`` in the algebraically simplified
    form ``(beta ** step - beta) / (beta ** step - 1)``.

    Args:
        beta (float): The original beta decay rate.
        step (int): Current optimization step number.

    Returns:
        float: The bias-corrected decay rate.

    """
    powered: float = math.pow(beta, step)
    numerator: float = powered - beta
    denominator: float = powered - 1.0
    return numerator / denominator

get_adanorm_gradient(grad, adanorm, exp_grad_norm=None, r=0.95) staticmethod

Get AdaNorm gradient.

Parameters:

Name Type Description Default
grad Tensor

Gradient.

required
adanorm bool

Whether to use the AdaNorm variant.

required
exp_grad_norm Optional[Tensor]

Exponential moving average of gradient norm.

None
r Optional[float]

EMA factor; between 0.9 and 0.99 is preferred.

0.95
Source code in pytorch_optimizer/base/optimizer.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
@staticmethod
def get_adanorm_gradient(
    grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
) -> torch.Tensor:
    r"""Get AdaNorm gradient.

    Updates the gradient-norm EMA in place, ``exp_grad_norm <- r * exp_grad_norm + (1 - r) * ||grad||``.
    If the updated EMA exceeds the current gradient norm, returns ``grad`` rescaled to have norm
    ``exp_grad_norm``; otherwise returns ``grad`` unchanged.

    Args:
        grad (torch.Tensor): Gradient. Never modified.
        adanorm (bool): Whether to use the AdaNorm variant. If False, ``grad`` is returned as-is.
        exp_grad_norm (Optional[torch.Tensor]): Exponential moving average of gradient norm, updated in
            place. If None, ``grad`` is returned as-is.
        r (Optional[float]): EMA factor; between 0.9 and 0.99 is preferred. None falls back to 0.95.

    Returns:
        torch.Tensor: The (possibly rescaled) gradient.

    """
    if not adanorm or exp_grad_norm is None:
        return grad

    if r is None:
        r = 0.95

    grad_norm = torch.linalg.norm(grad)

    # Must be in place (``mul_``), so the caller's running average actually advances between calls.
    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

    # A zero gradient has nothing to rescale; skipping it avoids 0 / 0 = NaN.
    if exp_grad_norm > grad_norm and grad_norm > 0:
        return grad.mul(exp_grad_norm).div_(grad_norm)

    return grad

get_rectify_step_size(is_rectify, step, lr, beta2, n_sma_threshold, degenerated_to_sgd) staticmethod

Get step size for rectify optimizer.

Parameters:

Name Type Description Default
is_rectify bool

Whether to apply the rectify variant.

required
step int

Current step number.

required
lr float

Base learning rate.

required
beta2 float

Beta2 parameter from optimizer (momentum term).

required
n_sma_threshold float

Simple Moving Average (SMA) threshold for rectification.

required
degenerated_to_sgd bool

Whether to degenerate to SGD if below threshold.

required
Source code in pytorch_optimizer/base/optimizer.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
@staticmethod
def get_rectify_step_size(
    is_rectify: bool,
    step: int,
    lr: float,
    beta2: float,
    n_sma_threshold: int,
    degenerated_to_sgd: bool,
) -> Tuple[float, float]:
    """Compute the RAdam-style rectified step size.

    Args:
        is_rectify (bool): Whether to apply the rectify variant. If False, ``(lr, 0.0)`` is returned.
        step (int): Current step number.
        lr (float): Base learning rate.
        beta2 (float): Beta2 parameter from the optimizer.
        n_sma_threshold (int): Simple Moving Average (SMA) length at or above which rectification is applied.
        degenerated_to_sgd (bool): Below the threshold, use a factor of 1.0 if True, otherwise -1.0.

    Returns:
        Tuple[float, float]: ``(step_size, n_sma)``. ``step_size`` is ``lr`` times the rectification
        factor, and ``n_sma`` is the SMA length (0.0 when rectification is disabled).

    """
    if not is_rectify:
        return lr, 0.0

    n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
    beta2_t: float = beta2 ** step  # fmt: skip
    n_sma: float = n_sma_max - 2 * step * beta2_t / (1.0 - beta2_t)

    if n_sma >= n_sma_threshold:
        variance_term: float = (
            (1.0 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
        )
        rt: float = math.sqrt(variance_term)
    else:
        # A negative factor acts as a sentinel for "skip the update" when not falling back to SGD.
        rt = 1.0 if degenerated_to_sgd else -1.0

    return lr * rt, n_sma

get_rms(x) staticmethod

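Get the root-mean-square (L2 norm divided by the square root of the element count) of a tensor, or of each tensor in a list.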

Source code in pytorch_optimizer/base/optimizer.py
311
312
313
314
315
316
317
318
319
320
321
@staticmethod
def get_rms(x: Union[List[torch.Tensor], torch.Tensor]) -> Union[List[torch.Tensor], torch.Tensor]:
    """Return the root-mean-square ``||x||_2 / sqrt(numel(x))`` of a tensor, or of each tensor in a list."""
    if not isinstance(x, torch.Tensor):
        rms_values = torch._foreach_norm(x, ord=2)
        torch._foreach_div_(rms_values, [math.sqrt(t.numel()) for t in x])
        return rms_values  # pyright: ignore[reportReturnType]

    return x.norm(2).div_(math.sqrt(x.numel()))

get_stable_adamw_rms(grad, exp_avg_sq, eps=1e-16) staticmethod

Get StableAdamW RMS.

Parameters:

Name Type Description Default
grad Tensor

Gradient.

required
exp_avg_sq Tensor

Exponential moving average of squared gradient.

required
eps float

Small value to prevent division by zero.

1e-16
Source code in pytorch_optimizer/base/optimizer.py
468
469
470
471
472
473
474
475
476
477
478
@staticmethod
def get_stable_adamw_rms(grad: torch.Tensor, exp_avg_sq: torch.Tensor, eps: float = 1e-16) -> float:
    """Compute the StableAdamW update-clipping RMS, ``max(1, sqrt(mean(grad**2 / max(exp_avg_sq, eps))))``.

    Args:
        grad (torch.Tensor): Gradient.
        exp_avg_sq (torch.Tensor): Exponential moving average of squared gradient.
        eps (float): Lower bound applied to ``exp_avg_sq`` to prevent division by zero.

    Returns:
        float: The RMS value, never smaller than 1.0.

    """
    safe_denominator = exp_avg_sq.clip(min=eps)
    ratio = grad.pow(2).div_(safe_denominator)
    rms = ratio.mean().sqrt_()
    return rms.clip_(min=1.0).item()

init_group(group, **kwargs) abstractmethod

Initialize the given parameter group of the optimizer. Abstract: subclasses must implement it, and it returns None.

Source code in pytorch_optimizer/base/optimizer.py
559
560
561
562
@abstractmethod
def init_group(self, group: ParamGroup, **kwargs) -> None:  # pragma: no cover
    """Initialize the given parameter group of the optimizer.

    Abstract: every concrete optimizer must implement it. Per the ``-> None`` signature, nothing is
    returned; in particular, no ``is_complex`` flag is returned.

    Args:
        group (ParamGroup): Parameter group to initialize.
        **kwargs: Implementation-specific options.

    """
    return

load_optimizer(optimizer, **kwargs) staticmethod

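Return a torch.optim.Optimizer instance: an existing instance is passed through unchanged, while an optimizer class is instantiated with the params keyword argument.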

Source code in pytorch_optimizer/base/optimizer.py
29
30
31
32
33
34
35
36
37
38
39
@staticmethod
def load_optimizer(optimizer: OptimizerInstanceOrClass, **kwargs) -> Optimizer:
    """Return a ``torch.optim.Optimizer`` instance, building it from a class if necessary.

    Args:
        optimizer: Either an already-constructed optimizer, returned unchanged, or an optimizer class,
            instantiated as ``optimizer(params, **kwargs)``.
        **kwargs: Must contain ``params`` when ``optimizer`` is a class. The remaining keyword arguments
            are forwarded to its constructor. Ignored when ``optimizer`` is an instance.

    Returns:
        Optimizer: The optimizer instance.

    Raises:
        ValueError: If ``optimizer`` is a class and ``params`` is missing from ``kwargs``.

    """
    if isinstance(optimizer, Optimizer):
        return optimizer

    if 'params' not in kwargs:
        # Only reachable when a class (not an instance) was given, so the message says "class".
        raise ValueError('need to pass `params` when you pass the `torch.optim.Optimizer` class.')

    params = kwargs.pop('params')
    return optimizer(params, **kwargs)

maximize_gradient(grad, maximize=False) staticmethod

Maximize the objective with respect to the params, instead of minimizing.

Source code in pytorch_optimizer/base/optimizer.py
576
577
578
579
580
@staticmethod
def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None:
    """Negate ``grad`` in place when ``maximize`` is set, so a minimizing update ascends the objective.

    Args:
        grad (torch.Tensor): Gradient tensor, modified in place.
        maximize (bool): If False (default), ``grad`` is left untouched.

    """
    if not maximize:
        return
    grad.neg_()

set_hessian(param_groups, state, hessian) staticmethod

Set hessian to state from external source. Generally useful when using functorch as a base.

Parameters:

Name Type Description Default
param_groups ParamsT

Parameter groups from the optimizer.

required
state State

Optimizer state dictionary.

required
hessian List[Tensor]

Sequence of Hessian tensors to set, one per parameter, in parameter-group order.

required
Example

# Hutchinson's Estimator using Hessian-vector product (HVP)
>>> noise = tree_map(lambda v: torch.randn_like(v), params)
>>> loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
>>> hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)

>>> optimizer.set_hessian(hessian_diag_est)
# OR
>>> optimizer.step(hessian=hessian_diag_est)

Source code in pytorch_optimizer/base/optimizer.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@staticmethod
@torch.no_grad()
def set_hessian(param_groups: ParamsT, state: State, hessian: List[torch.Tensor]) -> None:
    """Set hessian to state from external source. Generally useful when using functorch as a base.

    Args:
        param_groups: Parameter groups from the optimizer.
        state: Optimizer state dictionary; ``state[p]['hessian']`` is assigned for every parameter.
        hessian: Sequence of Hessian tensors, one per parameter, in parameter-group order.

    Raises:
        ValueError: If the number of Hessian tensors differs from the number of parameters, or a
            Hessian's shape differs from its parameter's shape.

    Example:
        # Hutchinson's Estimator using Hessian-vector product (HVP)
        >>> noise = tree_map(lambda v: torch.randn_like(v), params)
        >>> loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
        >>> hessian_diag_est = tree_map(lambda a, b: a * b, hvp_est, noise)

        >>> optimizer.set_hessian(hessian_diag_est)
        # OR
        >>> optimizer.step(hessian=hessian_diag_est)

    """
    params: List[torch.Tensor] = [p for group in param_groups or [] for p in group['params']]

    # Check the count up front. Too few Hessians would otherwise surface as an opaque IndexError, and
    # surplus Hessians would be silently ignored.
    if len(params) != len(hessian):
        raise ValueError(
            f'the number of parameters and hessians does not match. {len(params)} vs {len(hessian)}'
        )

    for p, h in zip(params, hessian):
        if p.size() != h.size():
            raise ValueError(f'the shape of parameter and hessian does not match. {p.size()} vs {h.size()}')

        state[p]['hessian'] = h

view_as_real(param, *state_and_grads) staticmethod

View complex tensors as real tensors.

Source code in pytorch_optimizer/base/optimizer.py
564
565
566
567
568
569
570
571
572
573
574
@staticmethod
def view_as_real(param, *state_and_grads) -> tuple:
    """Return real-valued views of a complex parameter and its complex companions.

    If ``param`` is not complex, everything is returned unchanged. Otherwise ``param`` and every complex
    entry of ``state_and_grads`` are replaced by ``torch.view_as_real`` views; real entries and ``None``
    pass through as-is.

    Returns:
        tuple: ``(param, *state_and_grads)`` after the conversion.

    """
    if not torch.is_complex(param):
        return param, *state_and_grads

    def _to_real(tensor):
        # ``None`` placeholders and already-real tensors are kept exactly as given.
        if tensor is not None and torch.is_complex(tensor):
            return torch.view_as_real(tensor)
        return tensor

    return torch.view_as_real(param), *(_to_real(tensor) for tensor in state_and_grads)

zero_hessian(param_groups, state, pre_zero=True) staticmethod

Zero-out Hessian.

Parameters:

Name Type Description Default
param_groups ParamsT

Parameter groups from the optimizer.

required
state State

Optimizer state dictionary.

required
pre_zero bool

If True, zero-out the Hessian before computing/updating it.

True
Source code in pytorch_optimizer/base/optimizer.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
@staticmethod
def zero_hessian(param_groups: ParamsT, state: State, pre_zero: bool = True) -> None:
    """Ensure each trainable parameter has a ``'hessian'`` state entry, optionally resetting it to zero.

    Only parameters that require grad and have a dense, non-None gradient are touched. A missing entry is
    created as ``zeros_like(p)``; an existing one is zeroed in place only when ``pre_zero`` is True.

    Args:
        param_groups (ParamsT): Parameter groups from the optimizer.
        state (State): Optimizer state dictionary.
        pre_zero (bool): If True, zero-out an existing Hessian before it is recomputed/updated.

    """
    for group in param_groups or []:
        for param in group['params']:
            if not param.requires_grad or param.grad is None or param.grad.is_sparse:
                continue

            param_state = state[param]
            if 'hessian' in param_state:
                if pre_zero:
                    param_state['hessian'].zero_()
            else:
                param_state['hessian'] = torch.zeros_like(param)