
Utilization

get_optimizer_parameters(model_or_parameter, weight_decay, wd_ban_list=('bias', 'LayerNorm.bias', 'LayerNorm.weight'))

Get optimizer parameters while filtering specified modules.

Parameters:

model_or_parameter (Union[nn.Module, List]): model or parameters. Required.
weight_decay (float): weight decay value. Required.
wd_ban_list (List[str]): list of parameter-name substrings that should not receive weight decay. Default: ('bias', 'LayerNorm.bias', 'LayerNorm.weight').

Returns:

PARAMETERS: new parameter list.

Source code in pytorch_optimizer/optimizer/utils.py
def get_optimizer_parameters(
    model_or_parameter: Union[nn.Module, List],
    weight_decay: float,
    wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
) -> PARAMETERS:
    r"""Get optimizer parameters while filtering specified modules.

    :param model_or_parameter: Union[nn.Module, List]. model or parameters.
    :param weight_decay: float. weight_decay.
    :param wd_ban_list: List[str]. ban list not to set weight decay.
    :returns: PARAMETERS. new parameter list.
    """
    if isinstance(model_or_parameter, nn.Module):
        model_or_parameter = list(model_or_parameter.named_parameters())

    return [
        {
            'params': [p for n, p in model_or_parameter if p.requires_grad and not any(nd in n for nd in wd_ban_list)],
            'weight_decay': weight_decay,
        },
        {
            'params': [p for n, p in model_or_parameter if p.requires_grad and any(nd in n for nd in wd_ban_list)],
            'weight_decay': 0.0,
        },
    ]
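
A minimal usage sketch (the model and hyper-parameter values below are illustrative, not taken from the library): the returned parameter groups can be passed directly to any torch.optim-style optimizer, so that parameters whose names match an entry of wd_ban_list receive no weight decay.

import torch
from torch import nn

from pytorch_optimizer.optimizer.utils import get_optimizer_parameters

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))

# two groups come back: one with weight_decay, one (names matching the ban list) with 0.0
param_groups = get_optimizer_parameters(model, weight_decay=1e-2)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)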

is_valid_parameters(parameters)

Check whether the parameters are valid.

Source code in pytorch_optimizer/optimizer/utils.py
def is_valid_parameters(parameters: PARAMETERS) -> bool:
    r"""Check where the parameters are valid."""
    return isinstance(parameters, (list, tuple)) and len(parameters) > 0 and isinstance(parameters[0], dict)

has_overflow(grad_norm)

Detect inf and NaN in grad_norm.

Source code in pytorch_optimizer/optimizer/utils.py
def has_overflow(grad_norm: torch.Tensor) -> bool:
    r"""Detect inf and NaN in grad_norm."""
    return bool(torch.logical_or(torch.isnan(grad_norm), torch.isinf(grad_norm)).any())

to_real(x)

Return real value of tensor.

Source code in pytorch_optimizer/optimizer/utils.py
def to_real(x: torch.Tensor) -> torch.Tensor:
    r"""Return real value of tensor."""
    return x.real if torch.is_complex(x) else x

normalize_gradient(x, use_channels=False, epsilon=1e-08)

Normalize gradient with stddev.

Parameters:

x (torch.Tensor): gradient. Required.
use_channels (bool): channel-wise normalization. Default: False.
epsilon (float): eps. Default: 1e-08.
Source code in pytorch_optimizer/optimizer/utils.py
def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: float = 1e-8):
    r"""Normalize gradient with stddev.

    :param x: torch.Tensor. gradient.
    :param use_channels: bool. channel-wise normalization.
    :param epsilon: float. eps.
    """
    size: int = x.dim()
    if size > 1 and use_channels:
        s = x.std(dim=tuple(range(1, size)), keepdim=True).add_(epsilon)
        x.div_(s)
    elif torch.numel(x) > 2:
        s = x.std().add_(epsilon)
        x.div_(s)
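
A short sketch of the in-place behaviour (the tensor below is a stand-in for a real gradient): the tensor is divided by its standard deviation plus epsilon, channel-wise when use_channels=True.

import torch

from pytorch_optimizer.optimizer.utils import normalize_gradient

grad = torch.randn(8, 3, 3)
normalize_gradient(grad, use_channels=True)  # modifies grad in place, using the per-channel std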

flatten_grad(grads)

Flatten the gradient.

Source code in pytorch_optimizer/optimizer/utils.py
def flatten_grad(grads: List[torch.Tensor]) -> torch.Tensor:
    r"""Flatten the gradient."""
    return torch.cat([grad.flatten() for grad in grads])

un_flatten_grad(grads, shapes)

Unflatten the gradient.

Source code in pytorch_optimizer/optimizer/utils.py
def un_flatten_grad(grads: torch.Tensor, shapes: List[int]) -> List[torch.Tensor]:
    r"""Unflatten the gradient."""
    idx: int = 0
    un_flatten_grads: List[torch.Tensor] = []
    for shape in shapes:
        length = np.prod(shape)
        un_flatten_grads.append(grads[idx:idx + length].view(shape).clone())  # fmt: skip
        idx += length
    return un_flatten_grads
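
flatten_grad and un_flatten_grad are inverses of each other; a minimal round-trip sketch (the shapes are arbitrary):

import torch

from pytorch_optimizer.optimizer.utils import flatten_grad, un_flatten_grad

grads = [torch.randn(2, 3), torch.randn(5)]
shapes = [g.shape for g in grads]

flat = flatten_grad(grads)                # 1D tensor of length 11
restored = un_flatten_grad(flat, shapes)  # list of tensors with the original shapes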

channel_view(x)

Do channel view.

Source code in pytorch_optimizer/optimizer/utils.py
def channel_view(x: torch.Tensor) -> torch.Tensor:
    r"""Do channel view."""
    return x.view(x.size()[0], -1)

layer_view(x)

Do layer view.

Source code in pytorch_optimizer/optimizer/utils.py
def layer_view(x: torch.Tensor) -> torch.Tensor:
    r"""Do layer view."""
    return x.view(1, -1)

cosine_similarity_by_view(x, y, eps, view_func)

Calculate cosine similarity by the view.

Parameters:

x (torch.Tensor): src. Required.
y (torch.Tensor): dst. Required.
eps (float): epsilon. Required.
view_func (Callable[[torch.Tensor], torch.Tensor]): view (channel or layer) function. Required.
Source code in pytorch_optimizer/optimizer/utils.py
def cosine_similarity_by_view(
    x: torch.Tensor,
    y: torch.Tensor,
    eps: float,
    view_func: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    r"""Calculate cosine similarity by the view.

    :param x: torch.Tensor. src.
    :param y: torch.Tensor. dst.
    :param eps: float. epsilon.
    :param view_func: Callable. view (channel or layer) function.
    """
    x = view_func(x)
    y = view_func(y)
    return f.cosine_similarity(x, y, dim=1, eps=eps).abs_()
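
A small sketch with channel_view as the view function (the tensors are placeholders); layer_view can be passed instead for a single layer-wise similarity.

import torch

from pytorch_optimizer.optimizer.utils import channel_view, cosine_similarity_by_view

x, y = torch.randn(4, 8), torch.randn(4, 8)
sim = cosine_similarity_by_view(x, y, eps=1e-8, view_func=channel_view)  # shape (4,)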

clip_grad_norm(parameters, max_norm=0.0, sync=False)

Clip gradient norms.

When used in combination with FSDP, this will also ensure that gradient norms are aggregated across all workers,
since each worker only stores its shard of the gradients.

Parameters:

parameters (PARAMETERS): parameters whose gradients we wish to clip. Required.
max_norm (float): maximum norm we wish the gradients to have. If non-positive, no clipping is performed. Default: 0.0.
sync (bool): whether to aggregate the norm across the distributed group. Used only in combination with FSDP. Default: False.

Returns:

Union[torch.Tensor, float]: the gradient norm across all parameters, before clipping.

Source code in pytorch_optimizer/optimizer/utils.py
def clip_grad_norm(
    parameters: PARAMETERS,
    max_norm: float = 0.0,
    sync: bool = False,
) -> Union[torch.Tensor, float]:  # pragma: no cover
    r"""Clip gradient norms.

        During combination with FSDP, will also ensure that grad norms are aggregated across all workers,
        since each worker only stores their shard of the gradients.

    :param parameters: PARAMETERS. Parameters whose gradients we wish to clip.
    :param max_norm: float. Maximum norm we wish the gradients to have. If non-positive, then we will not perform
        clipping.
    :param sync: bool. Boolean indicating whether we should aggregate across the distributed group. Used only in
        combination with FSDP.
    :returns: The gradient norm across all parameters, before clipping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # make sure any generators are expanded
    parameters = list(parameters)

    # if syncing we need to manually perform the clipping so that we aggregate properly
    if max_norm > 0 and not sync:
        return clip_grad_norm_(parameters, max_norm)

    norm_sq = sum(p.grad.norm() ** 2 for p in parameters if p.grad is not None)
    if sync:
        # also need to get the norms from all the other sharded workers in FSDP
        all_reduce(norm_sq)

    grad_norm = math.sqrt(norm_sq)
    if max_norm > 0:
        clip_coefficient = max_norm / (grad_norm + 1e-6)
        for p in parameters:
            p.grad.detach().mul_(clip_coefficient)

    return grad_norm
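
A minimal single-process sketch (model and data are placeholders). With max_norm > 0 and sync=False this simply delegates to torch.nn.utils.clip_grad_norm_; setting sync=True would additionally all-reduce the squared norms across the FSDP process group.

import torch
from torch import nn

from pytorch_optimizer.optimizer.utils import clip_grad_norm

model = nn.Linear(10, 2)
model(torch.randn(4, 10)).sum().backward()

grad_norm = clip_grad_norm(model.parameters(), max_norm=1.0)  # norm before clipping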

projection(p, grad, perturb, delta, wd_ratio, eps)

Project to remove the radial component from the update vector.

Source code in pytorch_optimizer/optimizer/utils.py
def projection(
    p: torch.Tensor,
    grad: torch.Tensor,
    perturb: torch.Tensor,
    delta: float,
    wd_ratio: float,
    eps: float,
) -> Tuple[torch.Tensor, float]:
    r"""Project to remove the radial component from the update vector."""
    wd: float = 1.0
    expand_size: List[int] = [-1] + [1] * (len(p.shape) - 1)
    for view_func in (channel_view, layer_view):
        cosine_sim = cosine_similarity_by_view(grad, p, eps, view_func)

        if cosine_sim.max() < delta / math.sqrt(view_func(p).size()[1]):
            p_n = p / view_func(p).norm(dim=1).view(expand_size).add_(eps)
            perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
            wd = wd_ratio
            return perturb, wd

    return perturb, wd

unit_norm(x, norm=2.0)

Get norm of unit.

Source code in pytorch_optimizer/optimizer/utils.py
def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
    r"""Get norm of unit."""
    keep_dim: bool = True
    dim: Optional[Union[int, Tuple[int, ...]]] = None

    x_len: int = len(x.shape)
    if x_len <= 1:
        keep_dim = False
    elif x_len in (2, 3):  # linear layers
        dim = 1
    elif x_len == 4:  # conv kernels
        dim = (1, 2, 3)
    else:
        dim = tuple(range(1, x_len))

    return x.norm(p=norm, dim=dim, keepdim=keep_dim)

neuron_norm(x)

Get norm of the tensor.

Source code in pytorch_optimizer/optimizer/utils.py
def neuron_norm(x: torch.Tensor) -> torch.Tensor:
    r"""Get norm of the tensor."""
    if x.dim() <= 1:
        return x.abs()

    view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1)

    return channel_view(x).norm(dim=1).view(*view_shape)

neuron_mean(x)

Get mean of the tensor.

Source code in pytorch_optimizer/optimizer/utils.py
def neuron_mean(x: torch.Tensor) -> torch.Tensor:
    r"""Get mean of the tensor."""
    if x.dim() <= 1:
        raise ValueError('[-] neuron_mean not defined on 1D tensors.')

    view_shape: List[int] = [x.shape[0]] + [1] * (x.dim() - 1)

    return channel_view(x).mean(dim=1).view(*view_shape)

disable_running_stats(model)

Disable running stats (momentum) of BatchNorm.

Source code in pytorch_optimizer/optimizer/utils.py
def disable_running_stats(model):
    r"""Disable running stats (momentum) of BatchNorm."""

    def _disable(module):
        if isinstance(module, _BatchNorm):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)

enable_running_stats(model)

Enable running stats (momentum) of BatchNorm.

Source code in pytorch_optimizer/optimizer/utils.py
def enable_running_stats(model):
    r"""Enable running stats (momentum) of BatchNorm."""

    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, 'backup_momentum'):
            module.momentum = module.backup_momentum

    model.apply(_enable)
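
These two helpers are typically wrapped around an extra forward pass (for example, the second pass of SAM-style optimizers) so that BatchNorm running statistics are only updated once per step. A minimal self-contained sketch:

import torch
from torch import nn

from pytorch_optimizer.optimizer.utils import disable_running_stats, enable_running_stats

model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))

disable_running_stats(model)  # momentum is backed up and set to 0
model(torch.randn(4, 8))      # running_mean / running_var stay untouched during this pass
enable_running_stats(model)   # momentum is restored from the backup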

l2_projection(parameters, max_norm=100.0)

Get l2 normalized parameter.

Source code in pytorch_optimizer/optimizer/utils.py
@torch.no_grad()
def l2_projection(parameters: PARAMETERS, max_norm: float = 1e2):
    r"""Get l2 normalized parameter."""
    global_norm = torch.sqrt(sum(p.norm().pow(2) for p in parameters))
    if global_norm > max_norm:
        ratio = max_norm / global_norm
        for param in parameters:
            param.mul_(ratio)

get_global_gradient_norm(param_groups, device)

Get global gradient norm.

Source code in pytorch_optimizer/optimizer/utils.py
@torch.no_grad()
def get_global_gradient_norm(param_groups: List[Dict], device: torch.device) -> torch.Tensor:
    r"""Get global gradient norm."""
    global_grad_norm = torch.zeros(1, dtype=torch.float32, device=device)

    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                global_grad_norm.add_(p.grad.norm().pow(2))

    return global_grad_norm

reduce_max_except_dim(x, dim)

Perform reduce-max along all dimensions except the given dim.

Parameters:

x (torch.Tensor): tensor to reduce-max. Required.
dim (int): dimension to exclude. Required.
Source code in pytorch_optimizer/optimizer/utils.py
@torch.no_grad()
def reduce_max_except_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    r"""Perform reduce-max along all dimensions except the given dim.

    :param x: torch.Tensor. tensor to reduce-max.
    :param dim: int. dimension to exclude.
    """
    rank: int = len(x.shape)
    if rank == 0:
        return x

    if dim >= rank:
        raise ValueError(f'[-] given dim is bigger than rank. {dim} >= {rank}')

    for d in range(rank):
        if d != dim:
            x = x.max(dim=d, keepdim=True).values
    return x

merge_small_dims(shape_to_merge, max_dim)

Merge small dimensions.

If there are some small dimensions, we collapse them
    e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
    [1, 2, 768, 1, 2048] --> [2, 768, 2048].

Parameters:

shape_to_merge (List[int]): shape to merge small dimensions. Required.
max_dim (int): maximal dimension of the output shape used in merging. Required.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def merge_small_dims(shape_to_merge: List[int], max_dim: int) -> List[int]:
    r"""Merge small dimensions.

        If there are some small dimensions, we collapse them
            e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
            [1, 2, 768, 1, 2048] --> [2, 768, 2048].

    :param shape_to_merge: List[int]. Shape to merge small dimensions.
    :param max_dim: int. Maximal dimension of output shape used in merging.
    """
    merged_shape: List[int] = []

    product: int = 1
    for dim in shape_to_merge:
        product *= dim
        if product > max_dim:
            merged_shape.append(product // dim)
            product = dim

    merged_shape.append(product)

    return merged_shape if len(merged_shape) > 1 else [1]
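
The examples from the docstring, reproduced as a quick check:

from pytorch_optimizer.optimizer.shampoo_utils import merge_small_dims

merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], max_dim=1024)  # [1024, 2048, 12]
merge_small_dims([1, 2, 768, 1, 2048], max_dim=1024)           # [2, 768, 2048]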

Newton methods

power_iteration(mat_g, num_iters=100)

Compute the maximum eigenvalue of matrix, for scaling.

For a symmetric PSD matrix, the power_iteration method is usually faster than a full eigen-decomposition (e.g. torch.linalg.eigvalsh).
The per-iteration validation of the singular-value error has also been removed to speed things up.

Parameters:

mat_g (torch.Tensor): the symmetric PSD matrix. Required.
num_iters (int): number of iterations. Default: 100.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
@torch.no_grad()
def power_iteration(mat_g: torch.Tensor, num_iters: int = 100) -> torch.Tensor:
    r"""Compute the maximum eigenvalue of matrix, for scaling.

        Mostly, power_iteration method is faster than torch.einval in case of the symmetric PSD matrix.
        Also, I removed the validation, error of singular value every iteration, so that boosting the speed.

    :param mat_g: torch.Tensor. the symmetric PSD matrix.
    :param num_iters: int. Number of iterations.
    """
    v = torch.randn(mat_g.shape[0], dtype=mat_g.dtype, device=mat_g.device)
    mat_v = torch.empty_like(v)

    for _ in range(num_iters):
        torch.mv(mat_g, v, out=mat_v)
        v = mat_v.div(torch.linalg.norm(mat_v))

    return (v.t() @ mat_g @ v).clamp_min_(1e-16)
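
A small sketch on a random symmetric PSD matrix; the estimate should be close to the largest eigenvalue returned by torch.linalg.eigvalsh.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import power_iteration

a = torch.randn(64, 32)
mat_g = a @ a.t()  # symmetric PSD

max_ev = power_iteration(mat_g, num_iters=100)  # close to torch.linalg.eigvalsh(mat_g).max()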

compute_power_schur_newton(mat_g, p, max_iters=100, error_tolerance=0.001, ridge_epsilon=1e-06, max_error_ratio=1.2)

Compute G^{-1/p} using a coupled Newton iteration.

See for example equation 3.2 on page 9 of:
    A Schur-Newton Method for the Matrix p-th Root and its Inverse by Chun-Hua Guo and Nicholas J. Higham
    SIAM Journal on Matrix Analysis and Applications, 2006, Vol. 28, No. 3 : pp. 788-804
    https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf.

The best value for z is (1 + p) * (c_max^{1/p} - c_min^{1/p}) / (c_max^{1+1/p} - c_min^{1+1/p})
where c_max and c_min are the largest and smallest singular values of mat_g.
The above estimate assumes that c_max > c_min * 2^p. It can be replaced by the simpler
z = (1 + p) / tf.trace(mat_g), which is less accurate and hence needs more iterations to converge.

If we want the method to always converge, use z = 1 / norm(mat_g) or z = 1 / tf.trace(mat_g),
but these can result in many extra iterations.

Parameters:

mat_g (torch.Tensor): a square positive semi-definite matrix. Required.
p (int): a positive integer. Required.
max_iters (int): stop iterating after this many rounds. Default: 100.
error_tolerance (float): threshold for stopping the iteration. Default: 0.001.
ridge_epsilon (float): we add this times I to G to make it positive definite; for scaling, it is multiplied by the largest eigenvalue of G. Default: 1e-06.
max_error_ratio (float): sometimes the error increases after an iteration before decreasing and converging; a factor of 1.2 bounds the maximal allowed increase. Default: 1.2.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
@torch.no_grad()
def compute_power_schur_newton(
    mat_g: torch.Tensor,
    p: int,
    max_iters: int = 100,
    error_tolerance: float = 1e-3,
    ridge_epsilon: float = 1e-6,
    max_error_ratio: float = 1.2,
) -> torch.Tensor:
    r"""Compute G^{-1/p} using a coupled Newton iteration.

        See for example equation 3.2 on page 9 of:
            A Schur-Newton Method for the Matrix p-th Root and its Inverse by Chun-Hua Guo and Nicholas J. Higham
            SIAM Journal on Matrix Analysis and Applications, 2006, Vol. 28, No. 3 : pp. 788-804
            https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf.

        The best value for z is (1 + p) * (c_max^{1/p} - c_min^{1/p}) / (c_max^{1+1/p} - c_min^{1+1/p})
        where c_max and c_min are the largest and smallest singular values of mat_g.
        The above estimate assumes that c_max > c_min * 2^p can replace above line by the one below,
        but it is less accurate, hence needs more iterations to converge.

        z = (1 + p) / tf.trace(mat_g)
        If we want the method to always converge, use z = 1 / norm(mat_g) or z = 1 / tf.trace(mat_g),
        but these can result in many extra iterations.

    :param mat_g: torch.Tensor. A square positive semi-definite matrix.
    :param p: int. a positive integer.
    :param max_iters: int. Stop iterating after this many rounds.
    :param error_tolerance: float. Threshold for stopping iteration.
    :param ridge_epsilon: float. We add this times I to G, to make is positive definite.
        For scaling, we multiply it by the largest eigenvalue of G.
    :param max_error_ratio: float. Sometimes error increases after an iteration before decreasing and converging.
        1.2 factor is used to bound the maximal allowed increase.
    """
    shape: List[int] = mat_g.shape
    if len(shape) == 1:
        return torch.pow(mat_g + ridge_epsilon, -1.0 / p)

    identity = torch.eye(shape[0], dtype=mat_g.dtype, device=mat_g.device)
    if shape[0] == 1:
        return identity

    mat_g += power_iteration(mat_g) * identity * ridge_epsilon

    z = (1 + p) / (2 * torch.linalg.norm(mat_g))

    mat_root = identity * torch.pow(z, 1.0 / p)

    mat_m = mat_g * z

    alpha: float = -1.0 / p
    alpha_identity = (1.0 - alpha) * identity

    prev_error = torch.max(torch.abs(mat_m - identity))

    for _ in range(max_iters):
        mat_m_i = alpha_identity + alpha * mat_m

        new_mat_root = torch.matmul(mat_root, mat_m_i)
        torch.matmul(torch.linalg.matrix_power(mat_m_i, p), mat_m, out=mat_m)

        error = torch.max(torch.abs(mat_m - identity))

        # NOTE
        # this is the main bottleneck that makes Scalable Shampoo slow,
        # because it is handled on the Python side, so values need to be on the CPU,
        # while XLA devices (e.g. TPU) don't seem to be affected.
        if torch.logical_or(error > prev_error * max_error_ratio, error <= error_tolerance):
            break

        mat_root = new_mat_root
        prev_error = error

    return mat_root
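
A sketch computing G^{-1/4}, the exponent used for the pre-conditioner of a rank-2 tensor (p = 2 * rank). The matrix below is a random PSD stand-in for a Shampoo statistics matrix, and double precision is used only for a tighter check.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import compute_power_schur_newton

a = torch.randn(32, 16, dtype=torch.float64)
mat_g = a @ a.t() + 1e-3 * torch.eye(32, dtype=torch.float64)  # PSD statistics matrix

# note: mat_g is modified in place (a ridge term is added before the iteration)
inv_root = compute_power_schur_newton(mat_g, p=4)  # approximately mat_g^{-1/4}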

compute_power_svd(matrix, power)

Compute G^{-1/p} using a SVD.

The SVD is calculated on the GPU. SVD on the CPU is sometimes faster, but in several experiments
CUDA was much faster than the CPU.

Parameters:

matrix (torch.Tensor): a square positive semi-definite matrix. Required.
power (float): rank. Required.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
@torch.no_grad()
def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
    r"""Compute G^{-1/p} using a SVD.

        Calculate SVD on the GPU. Sometimes, SVD on the CPU is faster than GPU, but based on the several experiments,
        CUDA seems much faster than on CPU.

    :param matrix: torch.Tensor. a square positive semi-definite matrix.
    :param power: float. rank.
    """
    u, s, vh = torch.linalg.svd(matrix, full_matrices=False)
    s.pow_(-1.0 / power)
    return u @ (s.diag() if len(matrix.shape) == 2 else s.diag_embed()) @ vh
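
The SVD-based counterpart, sketched under the same random PSD assumption as above:

import torch

from pytorch_optimizer.optimizer.shampoo_utils import compute_power_svd

a = torch.randn(32, 16)
mat_g = a @ a.t() + 1e-3 * torch.eye(32)

inv_root = compute_power_svd(mat_g, power=4)  # mat_g^{-1/4} via SVD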

Grafting

Graft

Base class to perform grafting onto Shampoo. This class does no grafting.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class Graft:
    r"""Base class to perform grafting onto Shampoo. This class does no grafting."""

    def __init__(self, *args):
        pass

    def add_statistics(self, grad: torch.Tensor, unused_beta2: float):
        r"""Add the statistics."""
        pass

    def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
        r"""Get preconditioned gradient."""
        return grad

    def update_momentum(self, update: torch.Tensor, unused_beta1: float) -> torch.Tensor:  # noqa: ARG002
        r"""Update momentum."""
        return update

add_statistics(grad, unused_beta2)

Add the statistics.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def add_statistics(self, grad: torch.Tensor, unused_beta2: float):
    r"""Add the statistics."""
    pass

precondition_gradient(grad)

Get preconditioned gradient.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
    r"""Get preconditioned gradient."""
    return grad

update_momentum(update, unused_beta1)

Update momentum.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def update_momentum(self, update: torch.Tensor, unused_beta1: float) -> torch.Tensor:  # noqa: ARG002
    r"""Update momentum."""
    return update

LayerWiseGrafting

Bases: IntEnum

Layer-wise grafting.

Grafting is a technique to fix the layer-wise scale of the Shampoo optimizer. https://arxiv.org/pdf/2002.11803.pdf studies this in detail. It allows us to plug the Shampoo optimizer into settings where SGD/AdaGrad is already well tuned. Grafting onto Shampoo means taking the Shampoo direction but using the step magnitude from the grafted optimizer, such as AdaGrad or SGD.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class LayerWiseGrafting(IntEnum):
    r"""Layer-wise grafting.

    Grafting is a technique to fix the layer-wise scale of Shampoo optimizer.
    https://arxiv.org/pdf/2002.11803.pdf studies this in detail. This
    allows us to plugin the Shampoo optimizer into settings where SGD/AdaGrad
    is already well tuned. Grafting onto Shampoo means take the Shampoo direction,
    but use the step magnitude from the grafted optimizer such as Adagrad or SGD.
    """

    NONE = 0
    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    SQRTN = 4

SGDGraft

Bases: Graft

Graft using SGD + momentum. Momentum maintains an exponentially weighted moving average of gradients.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class SGDGraft(Graft):
    r"""Graft using SGD + momentum. momentum maintains an exponentially weighted moving average of gradients."""

    def __init__(self, var: torch.Tensor):
        super().__init__(var)
        self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device)

    def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
        r"""Update momentum."""
        self.momentum.mul_(beta1).add_(update)
        return self.momentum

update_momentum(update, beta1)

Update momentum.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
    r"""Update momentum."""
    self.momentum.mul_(beta1).add_(update)
    return self.momentum

SQRTNGraft

Bases: Graft

Graft using SQRTN.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class SQRTNGraft(Graft):
    r"""Graft using SQRTN."""

    def __init__(self, var: torch.Tensor):
        super().__init__(var)

    def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
        r"""Get preconditioned gradient."""
        return grad.sign()

precondition_gradient(grad)

Get preconditioned gradient.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
    r"""Get preconditioned gradient."""
    return grad.sign()

AdaGradGraft

Bases: SGDGraft

Graft using AdaGrad. Essentially an implementation of AdaGrad with momentum.

Parameters:

var (torch.Tensor): variable. Required.
diagonal_eps (float): diagonal epsilon. Required.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class AdaGradGraft(SGDGraft):
    r"""Graft using AdaGrad. Essentially an implementation of AdaGrad with momentum.

    :param var: torch.Tensor. variable.
    :param diagonal_eps: float. diagonal epsilon.
    """

    def __init__(self, var: torch.Tensor, diagonal_eps: float):
        super().__init__(var)
        self.diagonal_eps = diagonal_eps
        self.statistics: torch.Tensor = torch.zeros_like(var)

    def add_statistics(self, grad: torch.Tensor, _):
        r"""Add the statistics."""
        self.statistics.add_(grad.pow(2))

    def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
        r"""Get preconditioned gradient."""
        return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)

add_statistics(grad, _)

Add the statistics.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def add_statistics(self, grad: torch.Tensor, _):
    r"""Add the statistics."""
    self.statistics.add_(grad.pow(2))

precondition_gradient(grad)

Get preconditioned gradient.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
    r"""Get preconditioned gradient."""
    return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)

RMSPropGraft

Bases: SGDGraft

Graft using RMSProp. Essentially an implementation of RMSProp with momentum.

Parameters:

var (torch.Tensor): variable. Required.
diagonal_eps (float): diagonal epsilon. Required.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class RMSPropGraft(SGDGraft):
    r"""Graft using RMSProp. Essentially an implementation of RMSProp with momentum.

    :param var: torch.Tensor. variable.
    :param diagonal_eps: float. diagonal epsilon.
    """

    def __init__(self, var: torch.Tensor, diagonal_eps: float):
        super().__init__(var)
        self.diagonal_eps = diagonal_eps
        self.statistics: torch.Tensor = torch.zeros_like(var)

    def add_statistics(self, grad: torch.Tensor, beta2: float):
        r"""Add the statistics."""
        self.statistics.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

    def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
        r"""Get preconditioned gradient."""
        return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)

add_statistics(grad, beta2)

Add the statistics.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def add_statistics(self, grad: torch.Tensor, beta2: float):
    r"""Add the statistics."""
    self.statistics.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

precondition_gradient(grad)

Get preconditioned gradient.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
    r"""Get preconditioned gradient."""
    return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)

build_graft(p, graft_type, diagonal_eps=1e-10)

Build Graft by given graft_type.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def build_graft(p: torch.Tensor, graft_type: int, diagonal_eps: float = 1e-10):
    r"""Build Graft by given graft_type."""
    if graft_type == LayerWiseGrafting.ADAGRAD:
        return AdaGradGraft(p, diagonal_eps)
    if graft_type == LayerWiseGrafting.RMSPROP:
        return RMSPropGraft(p, diagonal_eps)
    if graft_type == LayerWiseGrafting.SGD:
        return SGDGraft(p)
    if graft_type == LayerWiseGrafting.SQRTN:
        return SQRTNGraft(p)
    return Graft(p)
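
A sketch of how a graft is driven inside the Shampoo update loop (the tensors and beta values here are placeholders): statistics are accumulated from the raw gradient, the gradient is preconditioned to obtain the grafted step magnitude, and momentum is applied last.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import LayerWiseGrafting, build_graft

param, grad = torch.randn(16, 8), torch.randn(16, 8)

graft = build_graft(param, LayerWiseGrafting.RMSPROP, diagonal_eps=1e-10)

graft.add_statistics(grad, 0.999)               # accumulate second-moment statistics
graft_grad = graft.precondition_gradient(grad)  # RMSProp-scaled gradient
graft_update = graft.update_momentum(graft_grad, 0.9)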

Block Partitioner

BlockPartitioner

Partition a tensor into smaller tensors for preconditioning.

For example, if a variable has shape (4096, 512), we might split the 4096 into 4 blocks,
so we effectively have 4 variables of size (1024, 512) each.

Parameters:

var (torch.Tensor): tensor variable. Required.
rank (int): rank. Required.
block_size (int): block size. Required.
pre_conditioner_type (int): type of pre-conditioner. Required.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class BlockPartitioner:
    r"""Partition a tensor into smaller tensors for preconditioning.

        For example, if a variable has shape (4096, 512), we might split the 4096 into 4 blocks,
        so we effectively have 4 variables of size (1024, 512) each.

    :param var: torch.Tensor. tensor variable.
    :param rank: int. rank.
    :param block_size: int. block size.
    :param pre_conditioner_type: int type of pre-conditioner.
    """

    def __init__(self, var: torch.Tensor, rank: int, block_size: int, pre_conditioner_type: int):
        self.shape: List[int] = var.shape

        self.splits: List[Tuple[int, np.ndarray]] = []
        self.split_sizes: List[Tuple[int, np.ndarray]] = []

        split_sizes: List[np.ndarray] = []

        # We split var into smaller blocks. Here we store the metadata to make that split.
        for i, d in enumerate(self.shape):
            if block_size <= 0 or block_size >= d:
                split_sizes.append(np.array([d], dtype=np.int32))
                continue

            # d - 1, otherwise split appends a 0-size array.
            num_split: int = (d - 1) // block_size
            indices = (np.arange(num_split, dtype=np.int32) + 1) * block_size

            sizes: np.ndarray = np.ones(num_split + 1, dtype=np.int32) * block_size
            sizes[-1] = d - indices[-1]

            self.splits.append((i, indices))
            self.split_sizes.append((i, sizes))
            split_sizes.append(sizes)

        self.num_splits: int = len(split_sizes)
        self.pre_conditioner_shapes: List[List[int]] = self.build_pre_conditioner_shapes(
            split_sizes, pre_conditioner_type, rank
        )

    @staticmethod
    def build_pre_conditioner_shapes(
        split_sizes: List[np.ndarray], pre_conditioner_type: int, rank: int
    ) -> List[List[int]]:
        r"""Build pre-conditioner shapes."""
        pre_conditioner_shapes: List[List[int]] = []
        for t in itertools.product(*split_sizes):
            t_shape: List[Optional[List[int]]] = [[d, d] for d in t]
            if pre_conditioner_type == PreConditionerType.INPUT:
                t_shape = t_shape[:-1] + [None]
            if pre_conditioner_type == PreConditionerType.OUTPUT:
                t_shape = [None] * (rank - 1) + t_shape[-1:]
            pre_conditioner_shapes.extend(t_shape)
        return pre_conditioner_shapes

    def shapes_for_pre_conditioners(self) -> List[List[int]]:
        r"""Get shapes of pre-conditioner."""
        return self.pre_conditioner_shapes

    @torch.no_grad()
    def partition(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""Partition tensor into blocks."""
        if x.shape != self.shape:
            raise ValueError(f'self.shape != x.shape ({self.shape} vs {x.shape})')

        tensors = [x]
        for i, sizes in self.split_sizes:
            tensors = [torch.split(t, list(sizes), dim=i) for t in tensors]
            tensors = [t for tensor in tensors for t in tensor]
        return tensors

    def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor:
        r"""Merge partitions back to original shape."""
        for i, indices in reversed(self.splits):
            n: int = len(indices) + 1

            partitions: List[torch.Tensor] = [
                torch.cat(partitions[idx:idx + n], dim=i) for idx in range(0, len(partitions), n)  # fmt: skip
            ]

        return partitions[0]
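
A sketch of the (4096, 512) example from the class docstring; rank=2 and PreConditionerType.ALL are assumptions chosen only for the demo.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import BlockPartitioner, PreConditionerType

var = torch.zeros(4096, 512)
partitioner = BlockPartitioner(var, rank=2, block_size=1024, pre_conditioner_type=PreConditionerType.ALL)

blocks = partitioner.partition(var)            # 4 blocks of shape (1024, 512)
merged = partitioner.merge_partitions(blocks)  # back to the original (4096, 512)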

build_pre_conditioner_shapes(split_sizes, pre_conditioner_type, rank) staticmethod

Build pre-conditioner shapes.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
@staticmethod
def build_pre_conditioner_shapes(
    split_sizes: List[np.ndarray], pre_conditioner_type: int, rank: int
) -> List[List[int]]:
    r"""Build pre-conditioner shapes."""
    pre_conditioner_shapes: List[List[int]] = []
    for t in itertools.product(*split_sizes):
        t_shape: List[Optional[List[int]]] = [[d, d] for d in t]
        if pre_conditioner_type == PreConditionerType.INPUT:
            t_shape = t_shape[:-1] + [None]
        if pre_conditioner_type == PreConditionerType.OUTPUT:
            t_shape = [None] * (rank - 1) + t_shape[-1:]
        pre_conditioner_shapes.extend(t_shape)
    return pre_conditioner_shapes

merge_partitions(partitions)

Merge partitions back to original shape.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def merge_partitions(self, partitions: List[torch.Tensor]) -> torch.Tensor:
    r"""Merge partitions back to original shape."""
    for i, indices in reversed(self.splits):
        n: int = len(indices) + 1

        partitions: List[torch.Tensor] = [
            torch.cat(partitions[idx:idx + n], dim=i) for idx in range(0, len(partitions), n)  # fmt: skip
        ]

    return partitions[0]

partition(x)

Partition tensor into blocks.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
@torch.no_grad()
def partition(self, x: torch.Tensor) -> List[torch.Tensor]:
    r"""Partition tensor into blocks."""
    if x.shape != self.shape:
        raise ValueError(f'self.shape != x.shape ({self.shape} vs {x.shape})')

    tensors = [x]
    for i, sizes in self.split_sizes:
        tensors = [torch.split(t, list(sizes), dim=i) for t in tensors]
        tensors = [t for tensor in tensors for t in tensor]
    return tensors

shapes_for_pre_conditioners()

Get shapes of pre-conditioner.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def shapes_for_pre_conditioners(self) -> List[List[int]]:
    r"""Get shapes of pre-conditioner."""
    return self.pre_conditioner_shapes

Pre-Conditioner

PreConditionerType

Bases: IntEnum

Type of PreConditioner.

By default (ALL), a pre-conditioner is computed for each dim. INPUT/OUTPUT is one-sided Shampoo, which pre-conditions only the input/output dim. The last dim is assumed to always be the output dim, and everything else the input dims.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class PreConditionerType(IntEnum):
    r"""Type of PreConditioner.

    In default (ALL), computes pre-conditioner for each dim.
    INPUT/OUTPUT is one-sided Shampoo, in this case only on input/output dim.
    Assumes last dim is always the output dim and everything else input dim.
    """

    ALL = 0
    INPUT = 1
    OUTPUT = 2

PreConditioner

Compute statistics & shape from gradients for preconditioning.

Parameters:

var (torch.Tensor): variable. Required.
beta2 (float): beta2. Required.
inverse_exponent_override (int): override for the inverse exponent (used instead of 2 * rank when > 0). Required.
block_size (int): size of block. Required.
skip_preconditioning_rank_lt (int): skip preconditioning for parameters whose rank is lower than this value. Required.
no_preconditioning_for_layers_with_dim_gt (int): skip preconditioning for layers with any dimension greater than this value. Required.
shape_interpretation (bool): reshape the parameter by merging small dimensions. Required.
pre_conditioner_type (int): type of pre-conditioner. Required.
matrix_eps (float): epsilon of matrix. Default: 1e-06.
use_svd (bool): use SVD instead of the Schur-Newton method to calculate M^{-1/p}. Default: False.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
class PreConditioner:
    r"""Compute statistics & shape from gradients for preconditioning.

    :param var: torch.Tensor. variable.
    :param beta2: float. beta2.
    :param inverse_exponent_override: int. override inv exp.
    :param block_size: int. size of block.
    :param skip_preconditioning_rank_lt: int. skip low-rank parameter.
    :param no_preconditioning_for_layers_with_dim_gt: int. skip large size of dim of parameter.
    :param shape_interpretation: bool. reshaping parameter.
    :param pre_conditioner_type: int. type of pre-conditioner.
    :param matrix_eps: float. epsilon of matrix.
    :param use_svd: bool. use SVD instead of Schur-Newton method to calculate M^{-1/p}.
    """

    def __init__(
        self,
        var: torch.Tensor,
        beta2: float,
        inverse_exponent_override: int,
        block_size: int,
        skip_preconditioning_rank_lt: int,
        no_preconditioning_for_layers_with_dim_gt: int,
        shape_interpretation: bool,
        pre_conditioner_type: int,
        matrix_eps: float = 1e-6,
        use_svd: bool = False,
    ):
        self.beta2 = beta2
        self.inverse_exponent_override = inverse_exponent_override
        self.skip_preconditioning_rank_lt = skip_preconditioning_rank_lt
        self.no_preconditioning_for_layers_with_dim_gt = no_preconditioning_for_layers_with_dim_gt
        self.pre_conditioner_type = pre_conditioner_type
        self.matrix_eps = matrix_eps
        self.use_svd = use_svd

        self.w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2)

        self.original_shape: List[int] = var.shape
        self.transformed_shape: List[int] = (
            merge_small_dims(self.original_shape, block_size) if shape_interpretation else var.shape
        )

        self.should_precondition_dims: List[bool] = self.get_should_precondition_dims()
        self.rank: int = sum(self.should_precondition_dims)
        self.exponent_for_pre_conditioner: int = (
            self.inverse_exponent_override if self.inverse_exponent_override > 0 else 2 * self.rank
        )

        self.statistics: Union[List[torch.Tensor], torch.Tensor] = []
        self.pre_conditioners: Union[List[torch.Tensor], torch.Tensor] = []

        self.is_same_shapes: bool = False
        if len(self.transformed_shape) > 1 and not self.skip_precondition(var):
            self.partitioner = BlockPartitioner(
                var=torch.reshape(var, self.transformed_shape),
                rank=self.rank,
                block_size=block_size,
                pre_conditioner_type=self.pre_conditioner_type,
            )

            shapes: List[Optional[List[int]]] = self.partitioner.shapes_for_pre_conditioners()
            self.statistics = [self.matrix_eps * torch.eye(shape[0], device=var.device) for shape in shapes if shape]
            self.pre_conditioners = [torch.eye(shape[0], device=var.device) for shape in shapes if shape]
            self.is_same_shapes = None not in shapes and len(np.unique(shapes)) == 1

        if self.is_same_shapes:
            self.statistics = torch.stack(self.statistics, dim=0)
            self.pre_conditioners = torch.stack(self.pre_conditioners, dim=0)

    def get_should_precondition_dims(self) -> List[bool]:
        r"""Get pre-condition dimensions by the type of conditioner."""
        if self.pre_conditioner_type == PreConditionerType.ALL or len(self.transformed_shape) <= 1:
            return [True] * len(self.transformed_shape)
        if self.pre_conditioner_type == PreConditionerType.INPUT:
            return [True] * (len(self.transformed_shape) - 1) + [False]
        if self.pre_conditioner_type == PreConditionerType.OUTPUT:
            return [False] * (len(self.transformed_shape) - 1) + [True]
        raise ValueError

    def skip_precondition(self, x: torch.Tensor) -> bool:
        return (len(x.shape) < self.skip_preconditioning_rank_lt) or any(
            dim > self.no_preconditioning_for_layers_with_dim_gt for dim in x.shape
        )

    def add_statistics(self, grad: torch.Tensor):
        r"""Compute statistics from gradients and add to the correct state entries.

        :param grad: torch.Tensor. gradient to compute statistics from.
        """
        if len(self.statistics) == 0:
            return

        reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape)
        partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)

        for j in range(len(partitioned_grads)):
            partitioned_grad: torch.Tensor = partitioned_grads[j]
            for i in range(self.rank):
                axes: List[int] = [ax for ax in range(partitioned_grad.ndim) if ax != i]
                stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, dims=[axes, axes])
                self.statistics[j * self.rank + i].mul_(self.beta2).add_(stat, alpha=self.w2)

    def compute_pre_conditioners(self):
        r"""Compute L^{-1/exp} for each stats matrix L.

        If `self.use_svd` is enabled and where all shapes of statistics & pre-conditioners are same, perform batch SVD.
        else, SVD one by one.
        If `self.use_svd` is disabled, use Schur-Newton method, which is usually much faster.
        """
        if self.use_svd and self.is_same_shapes:
            self.pre_conditioners = compute_power_svd(matrix=self.statistics, power=self.exponent_for_pre_conditioner)
            return

        for i, statistic in enumerate(self.statistics):
            self.pre_conditioners[i] = (
                compute_power_svd(matrix=statistic, power=self.exponent_for_pre_conditioner)
                if self.use_svd
                else compute_power_schur_newton(
                    mat_g=statistic, p=self.exponent_for_pre_conditioner, ridge_epsilon=self.matrix_eps
                )
            )

    @staticmethod
    def precondition_block(
        partitioned_grad: torch.Tensor,
        should_preconditioned_dims: List[bool],
        pre_conditioners_for_grad: List[torch.Tensor],
    ) -> torch.Tensor:
        r"""Perform a preconditioning operation on a single gradient block.

        Loop invariant: the dimension to be preconditioned is first
        We keep all axes in the same cyclic order they were originally.
        """
        rank: int = len(partitioned_grad.shape)
        roll: Tuple[int, ...] = (*tuple(range(1, rank)), 0)

        i: int = 0
        for should_precondition_dim in should_preconditioned_dims:
            if not should_precondition_dim:
                partitioned_grad = torch.permute(partitioned_grad, roll)
                continue

            partitioned_grad = torch.tensordot(partitioned_grad, pre_conditioners_for_grad[i], dims=[[0], [0]])
            i += 1

        return partitioned_grad

    def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
        r"""Precondition the gradient.

        :param grad: torch.Tensor. a gradient tensor to precondition.
        """
        if len(self.pre_conditioners) == 0:
            return grad

        reshaped_grad = torch.reshape(grad, self.transformed_shape)
        partitioned_grads = self.partitioner.partition(reshaped_grad)

        pre_cond_partitioned_grads: List[torch.Tensor] = [
            self.precondition_block(
                partitioned_grad,
                self.should_precondition_dims,
                self.pre_conditioners[i * self.rank:(i + 1) * self.rank]  # fmt: skip
            )
            for i, partitioned_grad in enumerate(partitioned_grads)
        ]

        merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads)

        return torch.reshape(merged_grad, self.original_shape)
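
A sketch of the typical call sequence inside Scalable Shampoo: accumulate statistics every step, recompute the inverse-root pre-conditioners periodically, then precondition the gradient. All hyper-parameter values below are illustrative placeholders, not the library's defaults.

import torch

from pytorch_optimizer.optimizer.shampoo_utils import PreConditioner, PreConditionerType

param, grad = torch.randn(256, 128), torch.randn(256, 128)

pre_conditioner = PreConditioner(
    var=param,
    beta2=0.999,
    inverse_exponent_override=0,
    block_size=512,
    skip_preconditioning_rank_lt=1,
    no_preconditioning_for_layers_with_dim_gt=8192,
    shape_interpretation=True,
    pre_conditioner_type=PreConditionerType.ALL,
    matrix_eps=1e-6,
    use_svd=False,
)

pre_conditioner.add_statistics(grad)        # every step
pre_conditioner.compute_pre_conditioners()  # usually only every few steps
shampoo_grad = pre_conditioner.preconditioned_grad(grad)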

add_statistics(grad)

Compute statistics from gradients and add to the correct state entries.

Parameters:

grad (torch.Tensor): gradient to compute statistics from. Required.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def add_statistics(self, grad: torch.Tensor):
    r"""Compute statistics from gradients and add to the correct state entries.

    :param grad: torch.Tensor. gradient to compute statistics from.
    """
    if len(self.statistics) == 0:
        return

    reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape)
    partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)

    for j in range(len(partitioned_grads)):
        partitioned_grad: torch.Tensor = partitioned_grads[j]
        for i in range(self.rank):
            axes: List[int] = [ax for ax in range(partitioned_grad.ndim) if ax != i]
            stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, dims=[axes, axes])
            self.statistics[j * self.rank + i].mul_(self.beta2).add_(stat, alpha=self.w2)

compute_pre_conditioners()

Compute L^{-1/exp} for each stats matrix L.

If self.use_svd is enabled and all statistics and pre-conditioners share the same shape, a batched SVD is performed; otherwise the SVD is computed one matrix at a time. If self.use_svd is disabled, the Schur-Newton method is used, which is usually much faster.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def compute_pre_conditioners(self):
    r"""Compute L^{-1/exp} for each stats matrix L.

    If `self.use_svd` is enabled and where all shapes of statistics & pre-conditioners are same, perform batch SVD.
    else, SVD one by one.
    If `self.use_svd` is disabled, use Schur-Newton method, which is usually much faster.
    """
    if self.use_svd and self.is_same_shapes:
        self.pre_conditioners = compute_power_svd(matrix=self.statistics, power=self.exponent_for_pre_conditioner)
        return

    for i, statistic in enumerate(self.statistics):
        self.pre_conditioners[i] = (
            compute_power_svd(matrix=statistic, power=self.exponent_for_pre_conditioner)
            if self.use_svd
            else compute_power_schur_newton(
                mat_g=statistic, p=self.exponent_for_pre_conditioner, ridge_epsilon=self.matrix_eps
            )
        )

get_should_precondition_dims()

Get pre-condition dimensions by the type of conditioner.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def get_should_precondition_dims(self) -> List[bool]:
    r"""Get pre-condition dimensions by the type of conditioner."""
    if self.pre_conditioner_type == PreConditionerType.ALL or len(self.transformed_shape) <= 1:
        return [True] * len(self.transformed_shape)
    if self.pre_conditioner_type == PreConditionerType.INPUT:
        return [True] * (len(self.transformed_shape) - 1) + [False]
    if self.pre_conditioner_type == PreConditionerType.OUTPUT:
        return [False] * (len(self.transformed_shape) - 1) + [True]
    raise ValueError

precondition_block(partitioned_grad, should_preconditioned_dims, pre_conditioners_for_grad) staticmethod

Perform a preconditioning operation on a single gradient block.

Loop invariant: the dimension to be preconditioned is first. We keep all axes in the same cyclic order they were in originally.

Source code in pytorch_optimizer/optimizer/shampoo_utils.py
@staticmethod
def precondition_block(
    partitioned_grad: torch.Tensor,
    should_preconditioned_dims: List[bool],
    pre_conditioners_for_grad: List[torch.Tensor],
) -> torch.Tensor:
    r"""Perform a preconditioning operation on a single gradient block.

    Loop invariant: the dimension to be preconditioned is first
    We keep all axes in the same cyclic order they were originally.
    """
    rank: int = len(partitioned_grad.shape)
    roll: Tuple[int, ...] = (*tuple(range(1, rank)), 0)

    i: int = 0
    for should_precondition_dim in should_preconditioned_dims:
        if not should_precondition_dim:
            partitioned_grad = torch.permute(partitioned_grad, roll)
            continue

        partitioned_grad = torch.tensordot(partitioned_grad, pre_conditioners_for_grad[i], dims=[[0], [0]])
        i += 1

    return partitioned_grad

preconditioned_grad(grad)

Precondition the gradient.

Parameters:

grad (torch.Tensor): a gradient tensor to precondition. Required.
Source code in pytorch_optimizer/optimizer/shampoo_utils.py
def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
    r"""Precondition the gradient.

    :param grad: torch.Tensor. a gradient tensor to precondition.
    """
    if len(self.pre_conditioners) == 0:
        return grad

    reshaped_grad = torch.reshape(grad, self.transformed_shape)
    partitioned_grads = self.partitioner.partition(reshaped_grad)

    pre_cond_partitioned_grads: List[torch.Tensor] = [
        self.precondition_block(
            partitioned_grad,
            self.should_precondition_dims,
            self.pre_conditioners[i * self.rank:(i + 1) * self.rank]  # fmt: skip
        )
        for i, partitioned_grad in enumerate(partitioned_grads)
    ]

    merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads)

    return torch.reshape(merged_grad, self.original_shape)