Base

BaseOptimizer

Bases: ABC

Base optimizer class.

Source code in pytorch_optimizer/base/optimizer.py
class BaseOptimizer(ABC):
    r"""Base optimizer class."""

    @staticmethod
    @torch.no_grad()
    def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
        r"""Set hessian to state from external source. Generally useful when using functorch as a base.

        Example:
        -------
            Here's an example::

                # Hutchinson's Estimator using HVP
                noise = tree_map(lambda v: torch.randn_like(v), params)
                loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
                hessian_diag_est  = tree_map(lambda a, b: a * b, hvp_est, noise)

                optimizer.set_hessian(hessian_diag_est)
                # OR
                optimizer.step(hessian=hessian_diag_est)

        :param param_groups: PARAMETERS. parameter groups.
        :param state: STATE. optimizer state.
        :param hessian: List[torch.Tensor]. sequence of hessian to set.
        """
        i: int = 0
        for group in param_groups:
            for p in group['params']:
                if p.size() != hessian[i].size():
                    raise ValueError(
                        f'[-] the shape of parameter and hessian does not match. {p.size()} vs {hessian[i].size()}'
                    )

                state[p]['hessian'] = hessian[i]
                i += 1

    @staticmethod
    def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True):
        r"""Zero-out hessian.

        :param param_groups: PARAMETERS. parameter groups.
        :param state: STATE. optimizer state.
        :param pre_zero: bool. zero-out hessian before computing the hessian.
        """
        for group in param_groups:
            for p in group['params']:
                if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
                    if 'hessian' not in state[p]:
                        state[p]['hessian'] = torch.zeros_like(p)
                    elif pre_zero:
                        state[p]['hessian'].zero_()

    @staticmethod
    @torch.no_grad()
    def compute_hutchinson_hessian(
        param_groups: PARAMETERS,
        state: STATE,
        num_samples: int = 1,
        alpha: float = 1.0,
        distribution: HUTCHINSON_G = 'gaussian',
    ):
        r"""Hutchinson's approximate hessian, added to the state under key `hessian`.

        :param param_groups: PARAMETERS. parameter groups.
        :param state: STATE. optimizer state.
        :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
        :param alpha: float. alpha.
        :param distribution: HUTCHINSON_G. type of distribution.
        """
        if distribution not in ('gaussian', 'rademacher'):
            raise NotImplementedError(f'[-] Hessian with distribution {distribution} is not implemented.')

        params: List[torch.Tensor] = [
            p
            for group in param_groups
            for p in group['params']
            if p.requires_grad and p.grad is not None and not p.grad.is_sparse
        ]
        if len(params) == 0:
            return

        grads = [p.grad for p in params]

        for i in range(num_samples):
            if distribution == 'rademacher':
                zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]  # sample ±1 with equal probability
            else:
                zs = [torch.randn_like(p) for p in params]

            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)

    @staticmethod
    def apply_weight_decay(
        p: torch.Tensor,
        grad: Optional[torch.Tensor],
        lr: float,
        weight_decay: float,
        weight_decouple: bool,
        fixed_decay: bool,
        ratio: Optional[float] = None,
    ):
        r"""Apply weight decay.

        :param p: torch.Tensor. parameter.
        :param grad: Optional[torch.Tensor]. gradient.
        :param lr: float. learning rate.
        :param weight_decay: float. weight decay (L2 penalty).
        :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
        :param fixed_decay: bool. fix weight decay.
        :param ratio: Optional[float]. scale weight decay.
        """
        if weight_decouple:
            p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
        elif weight_decay > 0.0 and grad is not None:
            grad.add_(p, alpha=weight_decay)

    @staticmethod
    def apply_ams_bound(
        ams_bound: bool, exp_avg_sq: torch.Tensor, max_exp_avg_sq: Optional[torch.Tensor], eps: float
    ) -> torch.Tensor:
        r"""Apply AMSBound variant.

        :param ams_bound: bool. whether to apply AMSBound.
        :param exp_avg_sq: torch.Tensor. exp_avg_sq.
        :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
        :param eps: float. epsilon.
        """
        if ams_bound:
            torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
            de_nom = max_exp_avg_sq.add(eps)
        else:
            de_nom = exp_avg_sq.add(eps)

        return de_nom.sqrt_().add_(eps)

    @staticmethod
    def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
        r"""Apply AdamD variant.

        :param adam_debias: bool. whether to apply AdamD.
        :param step_size: float. step size.
        :param bias_correction1: float. bias_correction.
        """
        return step_size if adam_debias else step_size / bias_correction1

    @staticmethod
    def get_rectify_step_size(
        is_rectify: bool,
        step: int,
        lr: float,
        beta2: float,
        n_sma_threshold: int,
        degenerated_to_sgd: bool,
    ) -> Tuple[float, float]:
        r"""Get step size for rectify optimizer.

        :param is_rectify: bool. whether to apply rectify-variant.
        :param step: int. number of steps.
        :param lr: float. learning rate.
        :param beta2: float. beta2.
        :param n_sma_threshold: int. SMA threshold.
        :param degenerated_to_sgd: bool. degenerated to SGD.
        """
        step_size: float = lr
        n_sma: float = 0.0

        if is_rectify:
            n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
            beta2_t: float = beta2 ** step  # fmt: skip
            n_sma: float = n_sma_max - 2 * step * beta2_t / (1.0 - beta2_t)

            if n_sma >= n_sma_threshold:
                rt = math.sqrt(
                    (1.0 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
                )
            elif degenerated_to_sgd:
                rt = 1.0
            else:
                rt = -1.0

            step_size *= rt

        return step_size, n_sma

    @staticmethod
    def get_adanorm_gradient(
        grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
    ) -> torch.Tensor:
        r"""Get AdaNorm gradient.

        :param grad: torch.Tensor. gradient.
        :param adanorm: bool. whether to apply AdaNorm.
        :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
        :param r: Optional[float]. momentum (ratio).
        """
        if not adanorm:
            return grad

        grad_norm = torch.linalg.norm(grad)

        exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

        return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad

    @staticmethod
    def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)'):
        if range_type == '[)' and not low <= x < high:
            raise ValueError(f'[-] {name} must be in the range [{low}, {high})')
        if range_type == '[]' and not low <= x <= high:
            raise ValueError(f'[-] {name} must be in the range [{low}, {high}]')
        if range_type == '(]' and not low < x <= high:
            raise ValueError(f'[-] {name} must be in the range ({low}, {high}]')
        if range_type == '()' and not low < x < high:
            raise ValueError(f'[-] {name} must be in the range ({low}, {high})')

    @staticmethod
    def validate_non_negative(x: Optional[float], name: str):
        if x is not None and x < 0.0:
            raise ValueError(f'[-] {name} must be non-negative')

    @staticmethod
    def validate_positive(x: Union[float, int], name: str):
        if x <= 0:
            raise ValueError(f'[-] {name} must be positive')

    @staticmethod
    def validate_boundary(constant: float, boundary: float, bound_type: str = 'upper'):
        if bound_type == 'upper' and constant > boundary:
            raise ValueError(f'[-] constant {constant} must be in a range of (-inf, {boundary}]')
        if bound_type == 'lower' and constant < boundary:
            raise ValueError(f'[-] constant {constant} must be in a range of [{boundary}, inf)')

    @staticmethod
    def validate_step(step: int, step_type: str):
        if step < 1:
            raise NegativeStepError(step, step_type=step_type)

    @staticmethod
    def validate_options(x: str, name: str, options: List[str]):
        if x not in options:
            opts: str = ' or '.join([f'\'{option}\'' for option in options]).strip()
            raise ValueError(f'[-] {name} {x} must be one of ({opts})')

    @staticmethod
    def validate_learning_rate(learning_rate: Optional[float]):
        if learning_rate is not None and learning_rate < 0.0:
            raise NegativeLRError(learning_rate)

    def validate_betas(self, betas: BETAS):
        self.validate_range(betas[0], 'beta1', 0.0, 1.0, range_type='[]')
        self.validate_range(betas[1], 'beta2', 0.0, 1.0, range_type='[]')

        if len(betas) < 3:
            return

        if betas[2] is not None:
            self.validate_range(betas[2], 'beta3', 0.0, 1.0, range_type='[]')

    def validate_nus(self, nus: Union[float, Tuple[float, float]]):
        if isinstance(nus, float):
            self.validate_range(nus, 'nu', 0.0, 1.0, range_type='[]')
        else:
            self.validate_range(nus[0], 'nu1', 0.0, 1.0, range_type='[]')
            self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]')

    @abstractmethod
    def reset(self):  # pragma: no cover
        raise NotImplementedError
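
A minimal sketch of how these helpers are typically combined with torch.optim.Optimizer to build a concrete optimizer. The MySGD class, its hyper-parameters, and the update rule are hypothetical and only illustrate how the validators and apply_weight_decay are meant to be used:

    import torch
    from torch.optim.optimizer import Optimizer

    from pytorch_optimizer.base.optimizer import BaseOptimizer


    class MySGD(Optimizer, BaseOptimizer):  # hypothetical example, not part of the library
        def __init__(self, params, lr: float = 1e-3, weight_decay: float = 0.0):
            self.validate_learning_rate(lr)
            self.validate_non_negative(weight_decay, 'weight_decay')

            super().__init__(params, {'lr': lr, 'weight_decay': weight_decay})

        def reset(self):
            pass

        @torch.no_grad()
        def step(self, closure=None):
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()

            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue

                    # decoupled (AdamW-style) weight decay, then a plain SGD update
                    self.apply_weight_decay(
                        p,
                        p.grad,
                        lr=group['lr'],
                        weight_decay=group['weight_decay'],
                        weight_decouple=True,
                        fixed_decay=False,
                    )
                    p.add_(p.grad, alpha=-group['lr'])

            return loss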

apply_adam_debias(adam_debias, step_size, bias_correction1) staticmethod

Apply AdamD variant.

Parameters:

    adam_debias (bool, required): whether to apply AdamD.
    step_size (float, required): step size.
    bias_correction1 (float, required): bias correction.

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float:
    r"""Apply AdamD variant.

    :param adam_debias: bool. whether to apply AdamD.
    :param step_size: float. step size.
    :param bias_correction1: float. bias_correction.
    """
    return step_size if adam_debias else step_size / bias_correction1
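
For illustration, a quick check of the two branches with arbitrary values (this snippet only assumes pytorch_optimizer is installed):

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    beta1, step, lr = 0.9, 10, 1e-3
    bias_correction1 = 1.0 - beta1 ** step

    # standard Adam debiasing divides the step size by (1 - beta1^t)
    print(BaseOptimizer.apply_adam_debias(False, lr, bias_correction1))  # ~0.001535
    # AdamD keeps the raw step size and skips the first-moment bias correction
    print(BaseOptimizer.apply_adam_debias(True, lr, bias_correction1))   # 0.001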

apply_ams_bound(ams_bound, exp_avg_sq, max_exp_avg_sq, eps) staticmethod

Apply AMSBound variant.

Parameters:

    ams_bound (bool, required): whether to apply AMSBound.
    exp_avg_sq (torch.Tensor, required): exp_avg_sq.
    max_exp_avg_sq (Optional[torch.Tensor], required): max_exp_avg_sq.
    eps (float, required): epsilon.

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
def apply_ams_bound(
    ams_bound: bool, exp_avg_sq: torch.Tensor, max_exp_avg_sq: Optional[torch.Tensor], eps: float
) -> torch.Tensor:
    r"""Apply AMSBound variant.

    :param ams_bound: bool. whether to apply AMSBound.
    :param exp_avg_sq: torch.Tensor. exp_avg_sq.
    :param max_exp_avg_sq: Optional[torch.Tensor]. max_exp_avg_sq.
    :param eps: float. epsilon.
    """
    if ams_bound:
        torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
        de_nom = max_exp_avg_sq.add(eps)
    else:
        de_nom = exp_avg_sq.add(eps)

    return de_nom.sqrt_().add_(eps)
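
A small sketch of what the helper does to hand-built statistics (values are arbitrary); note that max_exp_avg_sq is updated in-place when ams_bound is enabled:

    import torch

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    exp_avg_sq = torch.tensor([0.04, 0.09])
    max_exp_avg_sq = torch.tensor([0.16, 0.01])

    de_nom = BaseOptimizer.apply_ams_bound(True, exp_avg_sq, max_exp_avg_sq, eps=1e-8)
    print(max_exp_avg_sq)  # tensor([0.1600, 0.0900])  <- element-wise maximum kept in-place
    print(de_nom)          # ~tensor([0.4000, 0.3000]) <- sqrt of the bounded second moment, plus eps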

apply_weight_decay(p, grad, lr, weight_decay, weight_decouple, fixed_decay, ratio=None) staticmethod

Apply weight decay.

Parameters:

    p (torch.Tensor, required): parameter.
    grad (Optional[torch.Tensor], required): gradient.
    lr (float, required): learning rate.
    weight_decay (float, required): weight decay (L2 penalty).
    weight_decouple (bool, required): the optimizer uses decoupled weight decay as in AdamW.
    fixed_decay (bool, required): fix weight decay.
    ratio (Optional[float], default None): scale weight decay.

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
def apply_weight_decay(
    p: torch.Tensor,
    grad: Optional[torch.Tensor],
    lr: float,
    weight_decay: float,
    weight_decouple: bool,
    fixed_decay: bool,
    ratio: Optional[float] = None,
):
    r"""Apply weight decay.

    :param p: torch.Tensor. parameter.
    :param grad: Optional[torch.Tensor]. gradient.
    :param lr: float. learning rate.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param fixed_decay: bool. fix weight decay.
    :param ratio: Optional[float]. scale weight decay.
    """
    if weight_decouple:
        p.mul_(1.0 - weight_decay * (1.0 if fixed_decay else lr) * (ratio if ratio is not None else 1.0))
    elif weight_decay > 0.0 and grad is not None:
        grad.add_(p, alpha=weight_decay)
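
A short sketch contrasting the two decay modes on dummy tensors (arbitrary values); both branches modify their inputs in-place:

    import torch

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    # decoupled (AdamW-style): the parameter itself is shrunk, the gradient is untouched
    p, grad = torch.ones(3), torch.full((3,), 0.5)
    BaseOptimizer.apply_weight_decay(p, grad, lr=0.1, weight_decay=0.01, weight_decouple=True, fixed_decay=False)
    print(p)     # tensor([0.9990, 0.9990, 0.9990])  <- p *= 1 - weight_decay * lr

    # coupled (classic L2 penalty): the decay term is added to the gradient instead
    p, grad = torch.ones(3), torch.full((3,), 0.5)
    BaseOptimizer.apply_weight_decay(p, grad, lr=0.1, weight_decay=0.01, weight_decouple=False, fixed_decay=False)
    print(grad)  # tensor([0.5100, 0.5100, 0.5100])  <- grad += weight_decay * p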

compute_hutchinson_hessian(param_groups, state, num_samples=1, alpha=1.0, distribution='gaussian') staticmethod

Hutchinson's approximate hessian, added to the state under key hessian.

Parameters:

    param_groups (PARAMETERS, required): parameter groups.
    state (STATE, required): optimizer state.
    num_samples (int, default 1): number of times to sample z for the approximation of the hessian trace.
    alpha (float, default 1.0): alpha.
    distribution (HUTCHINSON_G, default 'gaussian'): type of distribution.

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
@torch.no_grad()
def compute_hutchinson_hessian(
    param_groups: PARAMETERS,
    state: STATE,
    num_samples: int = 1,
    alpha: float = 1.0,
    distribution: HUTCHINSON_G = 'gaussian',
):
    r"""Hutchinson's approximate hessian, added to the state under key `hessian`.

    :param param_groups: PARAMETERS. parameter groups.
    :param state: STATE. optimizer state.
    :param num_samples: int. number of times to sample `z` for the approximation of the hessian trace.
    :param alpha: float. alpha.
    :param distribution: HUTCHINSON_G. type of distribution.
    """
    if distribution not in ('gaussian', 'rademacher'):
        raise NotImplementedError(f'[-] Hessian with distribution {distribution} is not implemented.')

    params: List[torch.Tensor] = [
        p
        for group in param_groups
        for p in group['params']
        if p.requires_grad and p.grad is not None and not p.grad.is_sparse
    ]
    if len(params) == 0:
        return

    grads = [p.grad for p in params]

    for i in range(num_samples):
        if distribution == 'rademacher':
            zs = [torch.randint_like(p, 0, 2) * 2.0 - 1.0 for p in params]  # sample ±1 with equal probability
        else:
            zs = [torch.randn_like(p) for p in params]

        h_zs = torch.autograd.grad(grads, params, grad_outputs=zs, retain_graph=i < num_samples - 1)
        for h_z, z, p in zip(h_zs, zs, params):
            state[p]['hessian'].add_(h_z * z, alpha=alpha / num_samples)
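
A self-contained sketch of driving zero_hessian and compute_hutchinson_hessian by hand, outside any optimizer (the toy loss is chosen so the true Hessian diagonal is known); the gradient graph must be built with create_graph=True so Hessian-vector products can be taken:

    import torch

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    param = torch.randn(4, requires_grad=True)
    loss = (param ** 2).sum()
    loss.backward(create_graph=True)  # keep the graph for the Hessian-vector products

    param_groups = [{'params': [param]}]
    state = {param: {}}

    BaseOptimizer.zero_hessian(param_groups, state)
    BaseOptimizer.compute_hutchinson_hessian(param_groups, state, num_samples=2, distribution='rademacher')

    # for loss = sum(p^2) the exact Hessian diagonal is 2 everywhere
    print(state[param]['hessian'])  # ~tensor([2., 2., 2., 2.])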

get_adanorm_gradient(grad, adanorm, exp_grad_norm=None, r=0.95) staticmethod

Get AdaNorm gradient.

Parameters:

    grad (torch.Tensor, required): gradient.
    adanorm (bool, required): whether to apply AdaNorm.
    exp_grad_norm (Optional[torch.Tensor], default None): exp_grad_norm.
    r (Optional[float], default 0.95): momentum (ratio).

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
def get_adanorm_gradient(
    grad: torch.Tensor, adanorm: bool, exp_grad_norm: Optional[torch.Tensor] = None, r: Optional[float] = 0.95
) -> torch.Tensor:
    r"""Get AdaNorm gradient.

    :param grad: torch.Tensor. gradient.
    :param adanorm: bool. whether to apply AdaNorm.
    :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm.
    :param r: Optional[float]. momentum (ratio).
    """
    if not adanorm:
        return grad

    grad_norm = torch.linalg.norm(grad)

    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)

    return grad * exp_grad_norm / grad_norm if exp_grad_norm > grad_norm else grad
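
A quick numeric sketch (arbitrary values): when the exponential moving average of the gradient norm exceeds the current norm, the gradient is rescaled to that average:

    import torch

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    grad = torch.tensor([3.0, 4.0])       # L2 norm 5.0
    exp_grad_norm = torch.tensor(10.0)    # running EMA of the gradient norm, updated in-place

    out = BaseOptimizer.get_adanorm_gradient(grad, adanorm=True, exp_grad_norm=exp_grad_norm, r=0.95)
    print(exp_grad_norm)  # tensor(9.7500)  <- 0.95 * 10 + 0.05 * 5
    print(out)            # tensor([5.8500, 7.8000])  <- grad * 9.75 / 5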

get_rectify_step_size(is_rectify, step, lr, beta2, n_sma_threshold, degenerated_to_sgd) staticmethod

Get step size for rectify optimizer.

Parameters:

    is_rectify (bool, required): whether to apply the rectify variant.
    step (int, required): number of steps.
    lr (float, required): learning rate.
    beta2 (float, required): beta2.
    n_sma_threshold (int, required): SMA threshold.
    degenerated_to_sgd (bool, required): whether to degenerate to SGD when the rectification term is not tractable.

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
def get_rectify_step_size(
    is_rectify: bool,
    step: int,
    lr: float,
    beta2: float,
    n_sma_threshold: int,
    degenerated_to_sgd: bool,
) -> Tuple[float, float]:
    r"""Get step size for rectify optimizer.

    :param is_rectify: bool. whether to apply rectify-variant.
    :param step: int. number of steps.
    :param lr: float. learning rate.
    :param beta2: float. beta2.
    :param n_sma_threshold: int. SMA threshold.
    :param degenerated_to_sgd: bool. degenerated to SGD.
    """
    step_size: float = lr
    n_sma: float = 0.0

    if is_rectify:
        n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
        beta2_t: float = beta2 ** step  # fmt: skip
        n_sma: float = n_sma_max - 2 * step * beta2_t / (1.0 - beta2_t)

        if n_sma >= n_sma_threshold:
            rt = math.sqrt(
                (1.0 - beta2_t) * (n_sma - 4) / (n_sma_max - 4) * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2)
            )
        elif degenerated_to_sgd:
            rt = 1.0
        else:
            rt = -1.0

        step_size *= rt

    return step_size, n_sma
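
A sketch of the two regimes with hand-picked values (printed numbers are approximate): early on n_sma is below the threshold, so the step size degenerates to plain SGD; later a rectification factor rt (typically below 1) scales the learning rate:

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    step_size, n_sma = BaseOptimizer.get_rectify_step_size(
        is_rectify=True, step=2, lr=1e-3, beta2=0.999, n_sma_threshold=5, degenerated_to_sgd=True,
    )
    print(step_size, n_sma)  # 0.001, ~2.0 (below the threshold, rt = 1)

    step_size, n_sma = BaseOptimizer.get_rectify_step_size(
        is_rectify=True, step=1000, lr=1e-3, beta2=0.999, n_sma_threshold=5, degenerated_to_sgd=True,
    )
    print(step_size, n_sma)  # ~0.0005, ~836 (rectified step size)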

set_hessian(param_groups, state, hessian) staticmethod

Set hessian to state from external source. Generally useful when using functorch as a base.

Example:
Here's an example::

    # Hutchinson's Estimator using HVP
    noise = tree_map(lambda v: torch.randn_like(v), params)
    loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
    hessian_diag_est  = tree_map(lambda a, b: a * b, hvp_est, noise)

    optimizer.set_hessian(hessian_diag_est)
    # OR
    optimizer.step(hessian=hessian_diag_est)

Parameters:

    param_groups (PARAMETERS, required): parameter groups.
    state (STATE, required): optimizer state.
    hessian (List[torch.Tensor], required): sequence of hessian to set.

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
@torch.no_grad()
def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]):
    r"""Set hessian to state from external source. Generally useful when using functorch as a base.

    Example:
    -------
        Here's an example::

            # Hutchinson's Estimator using HVP
            noise = tree_map(lambda v: torch.randn_like(v), params)
            loss_, hvp_est = jvp(grad(run_model_fn), (params,), (noise,))
            hessian_diag_est  = tree_map(lambda a, b: a * b, hvp_est, noise)

            optimizer.set_hessian(hessian_diag_est)
            # OR
            optimizer.step(hessian=hessian_diag_est)

    :param param_groups: PARAMETERS. parameter groups.
    :param state: STATE. optimizer state.
    :param hessian: List[torch.Tensor]. sequence of hessian to set.
    """
    i: int = 0
    for group in param_groups:
        for p in group['params']:
            if p.size() != hessian[i].size():
                raise ValueError(
                    f'[-] the shape of parameter and hessian does not match. {p.size()} vs {hessian[i].size()}'
                )

            state[p]['hessian'] = hessian[i]
            i += 1
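
Calling the static helper directly on a hand-built param_groups/state pair (a sketch; concrete optimizers pass their own self.param_groups and self.state, as in the docstring example above):

    import torch

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    param = torch.randn(2, 3, requires_grad=True)
    param_groups = [{'params': [param]}]
    state = {param: {}}

    # an externally computed Hessian (e.g. a functorch diagonal estimate) with the parameter's shape
    hessian_diag = torch.full_like(param, 2.0)

    BaseOptimizer.set_hessian(param_groups, state, [hessian_diag])
    print(state[param]['hessian'])  # the tensor is stored under the 'hessian' key of the parameter's state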

zero_hessian(param_groups, state, pre_zero=True) staticmethod

Zero-out hessian.

Parameters:

    param_groups (PARAMETERS, required): parameter groups.
    state (STATE, required): optimizer state.
    pre_zero (bool, default True): zero-out hessian before computing the hessian.

Source code in pytorch_optimizer/base/optimizer.py
@staticmethod
def zero_hessian(param_groups: PARAMETERS, state: STATE, pre_zero: bool = True):
    r"""Zero-out hessian.

    :param param_groups: PARAMETERS. parameter groups.
    :param state: STATE. optimizer state.
    :param pre_zero: bool. zero-out hessian before computing the hessian.
    """
    for group in param_groups:
        for p in group['params']:
            if p.requires_grad and p.grad is not None and not p.grad.is_sparse:
                if 'hessian' not in state[p]:
                    state[p]['hessian'] = torch.zeros_like(p)
                elif pre_zero:
                    state[p]['hessian'].zero_()
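
A short sketch of the allocation and pre_zero behaviour on a dummy parameter; zero_hessian only touches parameters that already have a dense gradient:

    import torch

    from pytorch_optimizer.base.optimizer import BaseOptimizer

    param = torch.randn(3, requires_grad=True)
    param.grad = torch.randn(3)

    param_groups = [{'params': [param]}]
    state = {param: {}}

    BaseOptimizer.zero_hessian(param_groups, state)                   # allocates a zeroed 'hessian' buffer
    state[param]['hessian'].fill_(1.0)

    BaseOptimizer.zero_hessian(param_groups, state, pre_zero=False)   # existing buffer is left untouched
    print(state[param]['hessian'])  # tensor([1., 1., 1.])

    BaseOptimizer.zero_hessian(param_groups, state)                   # pre_zero=True clears it again
    print(state[param]['hessian'])  # tensor([0., 0., 0.])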