Learning Rate Scheduler

deberta_v3_large_lr_scheduler(model, layer_low_threshold=195, layer_middle_threshold=323, head_param_start=390, base_lr=2e-05, head_lr=0.0001, wd=0.01)

DeBERTa-v3 large layer-wise learning-rate scheduler. It builds per-layer parameter groups (each with its own learning rate and weight decay) that can be passed directly to an optimizer.

Reference: https://github.com/gilfernandes/commonlit

Parameters:

    model (nn.Module, required): model based on Hugging Face Transformers.
    layer_low_threshold (int, default 195): parameter index at the start of the 12-layer block.
    layer_middle_threshold (int, default 323): parameter index at the end of the 24-layer block.
    head_param_start (int, default 390): parameter index where the backbone ends and the head starts.
    base_lr (float, default 2e-05): base learning rate.
    head_lr (float, default 0.0001): learning rate for the head parameters.
    wd (float, default 0.01): weight decay.
Source code in pytorch_optimizer/lr_scheduler/experimental/deberta_v3_lr_scheduler.py
def deberta_v3_large_lr_scheduler(
    model: nn.Module,
    layer_low_threshold: int = 195,
    layer_middle_threshold: int = 323,
    head_param_start: int = 390,
    base_lr: float = 2e-5,
    head_lr: float = 1e-4,
    wd: float = 1e-2,
) -> PARAMETERS:
    """DeBERTa-v3 large layer-wise lr scheduler.

        Reference : https://github.com/gilfernandes/commonlit.

    :param model: nn.Module. model. based on Huggingface Transformers.
    :param layer_low_threshold: int. start of the 12 layers.
    :param layer_middle_threshold: int. end of the 24 layers.
    :param head_param_start: int. where the backbone ends (head starts).
    :param base_lr: float. base lr.
    :param head_lr: float. head_lr.
    :param wd: float. weight decay.
    """
    named_parameters = list(model.named_parameters())

    backbone_parameters = named_parameters[:head_param_start]
    head_parameters = named_parameters[head_param_start:]

    head_group = [params for (_, params) in head_parameters]

    parameters = [{'params': head_group, 'lr': head_lr}]

    for layer_num, (name, params) in enumerate(backbone_parameters):
        weight_decay: float = 0.0 if ('bias' in name) or ('LayerNorm.weight' in name) else wd

        lr = base_lr / 2.5  # 8e-6 with the default base_lr of 2e-5
        if layer_num >= layer_middle_threshold:
            lr = base_lr / 0.5  # 4e-5 with the default base_lr of 2e-5
        elif layer_num >= layer_low_threshold:
            lr = base_lr

        parameters.append({'params': params, 'weight_decay': weight_decay, 'lr': lr})

    return parameters
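
A minimal usage sketch: the function returns parameter groups that can be handed directly to any optimizer. The default thresholds assume the DeBERTa-v3-large parameter layout; the model name, task head, and optimizer below are illustrative.

import torch
from transformers import AutoModelForSequenceClassification

from pytorch_optimizer.lr_scheduler.experimental.deberta_v3_lr_scheduler import deberta_v3_large_lr_scheduler

# load a DeBERTa-v3 large model with a task head (model name is illustrative)
model = AutoModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-large', num_labels=2)

# build layer-wise parameter groups: lower/base/higher lr bands across the backbone, head_lr for the head
parameters = deberta_v3_large_lr_scheduler(model, base_lr=2e-5, head_lr=1e-4, wd=1e-2)

# each group carries its own lr and weight decay, so any optimizer can consume them
optimizer = torch.optim.AdamW(parameters)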

get_chebyshev_lr(lr, epoch, num_epochs, is_warmup=False)

Get the Chebyshev learning rate for the given epoch.

Parameters:

    lr (float, required): learning rate.
    epoch (int, required): current epoch.
    num_epochs (int, required): total number of epochs.
    is_warmup (bool, default False): whether the current stage is warm-up.
Source code in pytorch_optimizer/lr_scheduler/chebyshev.py
def get_chebyshev_lr(lr: float, epoch: int, num_epochs: int, is_warmup: bool = False) -> float:
    r"""Get chebyshev learning rate.

    :param lr: float. learning rate.
    :param epoch: int. current epochs.
    :param num_epochs: int. number of total epochs.
    :param is_warmup: bool. whether warm-up stage or not.
    """
    if is_warmup:
        return lr

    epoch_power: int = np.power(2, int(np.log2(num_epochs - 1)) + 1) if num_epochs > 1 else 1
    scheduler = get_chebyshev_schedule(epoch_power)

    idx: int = epoch - 2
    if idx < 0:
        idx = 0
    elif idx > len(scheduler) - 1:
        idx = len(scheduler) - 1

    chebyshev_value: float = scheduler[idx]

    return lr * chebyshev_value
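
A minimal sketch of driving an optimizer with get_chebyshev_lr once per epoch; the model, base learning rate, and epoch counts are illustrative.

import torch

from pytorch_optimizer.lr_scheduler.chebyshev import get_chebyshev_lr

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

base_lr, num_epochs, warmup_epochs = 1e-3, 16, 2

for epoch in range(num_epochs):
    # during warm-up the base lr is returned unchanged; afterwards it is scaled
    # by the Chebyshev schedule value for the current epoch
    lr = get_chebyshev_lr(base_lr, epoch, num_epochs, is_warmup=epoch < warmup_epochs)
    for group in optimizer.param_groups:
        group['lr'] = lr
    # ... run one training epoch here ...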

CosineAnnealingWarmupRestarts

Bases: _LRScheduler

Cosine annealing scheduler with linear warm-up and warm restarts.

Parameters:

    optimizer (Optimizer, required): wrapped optimizer instance.
    first_cycle_steps (int, required): number of steps in the first cycle.
    cycle_mult (float, default 1.0): multiplier applied to the cycle length at each restart.
    max_lr (float, default 0.0001): maximum learning rate.
    min_lr (float, default 1e-06): minimum learning rate.
    warmup_steps (int, default 0): number of warm-up steps.
    gamma (float, default 0.9): factor by which max_lr is decreased at each restart.
    last_epoch (int, default -1): index of the last epoch.
Source code in pytorch_optimizer/lr_scheduler/cosine_anealing.py
class CosineAnnealingWarmupRestarts(_LRScheduler):
    r"""CosineAnnealingWarmupRestarts.

    :param optimizer: Optimizer. wrapped optimizer instance.
    :param first_cycle_steps: int. first cycle step size.
    :param cycle_mult: float. cycle steps magnification.
    :param max_lr: float.
    :param min_lr: float.
    :param warmup_steps: int. number of warmup steps.
    :param gamma: float. decrease rate of lr by cycle.
    :param last_epoch: int. step size of the current cycle.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 1e-4,
        min_lr: float = 1e-6,
        warmup_steps: int = 0,
        gamma: float = 0.9,
        last_epoch: int = -1,
    ):
        if warmup_steps >= first_cycle_steps:
            raise ValueError(
                f'[-] warmup_steps must be smaller than first_cycle_steps. {warmup_steps} < {first_cycle_steps}'
            )

        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.base_max_lr = max_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        self.cur_cycle_steps = first_cycle_steps
        self.step_in_cycle = last_epoch
        self.last_epoch = last_epoch

        self.cycle: int = 0
        self.base_lrs: List[float] = []

        super().__init__(optimizer, last_epoch)

        self.init_lr()

    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self) -> List[float]:
        if self.step_in_cycle == -1:
            return self.base_lrs

        if self.step_in_cycle < self.warmup_steps:
            return [
                (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs
            ]

        return [
            base_lr
            + (self.max_lr - base_lr)
            * (
                1
                + math.cos(
                    math.pi * (self.step_in_cycle - self.warmup_steps) / (self.cur_cycle_steps - self.warmup_steps)
                )
            )
            / 2.0
            for base_lr in self.base_lrs
        ]

    def step(self, epoch: Optional[int] = None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
                )
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.0:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                n: int = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(
                    self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)
                )  # fmt: skip
                self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** n  # fmt: skip
        else:
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)  # fmt: skip
        self.last_epoch = math.floor(epoch)

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
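
A minimal usage sketch with illustrative step counts; scheduler.step() is called once per optimization step, after optimizer.step().

import torch

from pytorch_optimizer.lr_scheduler.cosine_anealing import CosineAnnealingWarmupRestarts

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=200,  # length of the first cycle, in steps
    cycle_mult=1.0,         # keep every cycle the same length
    max_lr=1e-4,
    min_lr=1e-6,
    warmup_steps=20,        # linear warm-up at the start of each cycle
    gamma=0.9,              # shrink max_lr by 10% after each restart
)

for step in range(600):
    # ... forward / backward ...
    optimizer.step()
    scheduler.step()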

LinearScheduler

Bases: BaseLinearWarmupScheduler

Linear LR Scheduler w/ linear warmup.

Source code in pytorch_optimizer/lr_scheduler/linear_warmup.py
class LinearScheduler(BaseLinearWarmupScheduler):
    r"""Linear LR Scheduler w/ linear warmup."""

    def _step(self) -> float:
        return self.max_lr + (self.min_lr - self.max_lr) * (self.step_t - self.warmup_steps) / (
            self.total_steps - self.warmup_steps
        )
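
To make the formula concrete, a small standalone sketch (with illustrative numbers) evaluates the same expression _step uses after warm-up: a straight line from max_lr down to min_lr over the remaining steps.

# illustrative values: 1,000 total steps, 100 warm-up steps, lr decaying from 1e-3 to 1e-5
max_lr, min_lr, total_steps, warmup_steps = 1e-3, 1e-5, 1000, 100


def linear_value(step_t: int) -> float:
    # the same expression as LinearScheduler._step
    return max_lr + (min_lr - max_lr) * (step_t - warmup_steps) / (total_steps - warmup_steps)


print(linear_value(100))   # 1e-3 at the end of warm-up
print(linear_value(550))   # midway between max_lr and min_lr
print(linear_value(1000))  # 1e-5 at the final step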

CosineScheduler

Bases: BaseLinearWarmupScheduler

Cosine LR Scheduler w/ linear warmup.

Source code in pytorch_optimizer/lr_scheduler/linear_warmup.py
class CosineScheduler(BaseLinearWarmupScheduler):
    r"""Cosine LR Scheduler w/ linear warmup."""

    def _step(self) -> float:
        phase: float = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps) * math.pi
        return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0

PolyScheduler

Bases: BaseLinearWarmupScheduler

Polynomial LR Scheduler w/ linear warmup.

Parameters:

    poly_order (float, default 0.5): order of the polynomial; the learning rate changes with the step count raised to this power.
Source code in pytorch_optimizer/lr_scheduler/linear_warmup.py
class PolyScheduler(BaseLinearWarmupScheduler):
    r"""Poly LR Scheduler.

    :param poly_order: float. lr scheduler decreases with steps.
    """

    def __init__(self, poly_order: float = 0.5, **kwargs):
        self.poly_order = poly_order

        if poly_order <= 0:
            raise ValueError(f'[-] poly_order must be positive. {poly_order}')

        super().__init__(**kwargs)

    def _step(self) -> float:
        return self.min_lr + (self.max_lr - self.min_lr) * (self.step_t - self.warmup_steps) ** self.poly_order
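
A hedged usage sketch for the three warm-up schedulers above, all of which share the BaseLinearWarmupScheduler constructor. That constructor is not reproduced in this section, so the keyword names t_max and init_lr below are assumptions inferred from the attributes used in _step; check the base class for the exact signature. LinearScheduler and CosineScheduler are constructed the same way, minus poly_order.

import torch

from pytorch_optimizer.lr_scheduler.linear_warmup import PolyScheduler

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# keyword names t_max and init_lr are assumed, not taken from this page
scheduler = PolyScheduler(
    poly_order=0.5,
    optimizer=optimizer,
    t_max=1000,
    max_lr=1e-3,
    min_lr=1e-6,
    init_lr=1e-6,
    warmup_steps=100,
)

for _ in range(1000):
    # ... forward / backward / optimizer.step() ...
    scheduler.step()  # assumes the base class exposes step(), like the other schedulers here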

ProportionScheduler

ProportionScheduler (Rho Scheduler of GSAM).

This scheduler outputs a value that evolves proportionally to the learning rate produced by lr_scheduler.

Parameters:

    lr_scheduler (required): learning rate scheduler to follow.
    max_lr (float, required): maximum learning rate.
    min_lr (float, default 0.0): minimum learning rate.
    max_value (float, default 2.0): maximum value of rho.
    min_value (float, default 2.0): minimum value of rho.
Source code in pytorch_optimizer/lr_scheduler/proportion.py
class ProportionScheduler:
    r"""ProportionScheduler (Rho Scheduler of GSAM).

        This scheduler outputs a value that evolves proportional to lr_scheduler.

    :param lr_scheduler: learning rate scheduler.
    :param max_lr: float. maximum lr.
    :param min_lr: float. minimum lr.
    :param max_value: float. maximum of rho.
    :param min_value: float. minimum of rho.
    """

    def __init__(
        self, lr_scheduler, max_lr: float, min_lr: float = 0.0, max_value: float = 2.0, min_value: float = 2.0
    ):
        self.lr_scheduler = lr_scheduler
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_value = max_value
        self.min_value = min_value

        self.step_t: int = 0
        self.last_lr: List[float] = []

        self.step()

    def get_lr(self) -> float:
        return self.last_lr[0]

    def step(self) -> float:
        self.step_t += 1

        if hasattr(self.lr_scheduler, 'last_lr'):
            lr = self.lr_scheduler.last_lr[0]
        else:
            lr = self.lr_scheduler.optimizer.param_groups[0]['lr']

        if self.max_lr > self.min_lr:
            value = self.min_value + (self.max_value - self.min_value) * (lr - self.min_lr) / (
                self.max_lr - self.min_lr
            )
        else:
            value = self.max_value

        self.last_lr = [value]

        return value
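
A minimal sketch pairing ProportionScheduler with a learning-rate scheduler so that rho (e.g. for GSAM) rises and falls in proportion to the learning rate; the wrapped scheduler and the value ranges are illustrative.

import torch

from pytorch_optimizer.lr_scheduler.proportion import ProportionScheduler

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)

# rho equals max_value when lr == max_lr and min_value when lr == min_lr
rho_scheduler = ProportionScheduler(
    lr_scheduler=lr_scheduler,
    max_lr=1e-3,
    min_lr=1e-6,
    max_value=2.0,
    min_value=0.2,
)

for _ in range(1000):
    # ... forward / backward / optimizer.step() (e.g. inside GSAM) ...
    lr_scheduler.step()
    rho_scheduler.step()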

REXScheduler

Bases: _LRScheduler

REX learning-rate scheduler, from "Revisiting Budgeted Training with an Improved Schedule".

Parameters:

    optimizer (Optimizer, required): wrapped optimizer instance.
    total_steps (int, required): number of steps to optimize over.
    max_lr (float, default 1.0): maximum learning rate.
    min_lr (float, default 0.0): minimum learning rate.
Source code in pytorch_optimizer/lr_scheduler/rex.py
class REXScheduler(_LRScheduler):
    r"""Revisiting Budgeted Training with an Improved Schedule.

    :param optimizer: Optimizer. wrapped optimizer instance.
    :param total_steps: int. number of steps to optimize.
    :param max_lr: float. max lr.
    :param min_lr: float. min lr.
    """

    def __init__(
        self,
        optimizer: OPTIMIZER,
        total_steps: int,
        max_lr: float = 1.0,
        min_lr: float = 0.0,
    ):
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr

        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # record the current value in self.last_lr to match the torch.optim.lr_scheduler API
        self.last_lr: List[float] = [self.max_lr]

        super().__init__(optimizer)

        self.init_lr()

    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self) -> float:
        return self.last_lr[0]

    def get_linear_lr(self) -> float:
        if self.step_t >= self.total_steps:
            return self.min_lr

        progress: float = self.step_t / self.total_steps

        return self.min_lr + (self.max_lr - self.min_lr) * ((1.0 - progress) / (1.0 - progress / 2.0))

    def step(self):
        value: float = self.get_linear_lr()

        self.step_t += 1

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = value

        self.last_lr = [value]

        return value
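
A minimal usage sketch; the model, step count, and learning rates are illustrative, and scheduler.step() is called once per optimization step.

import torch

from pytorch_optimizer.lr_scheduler.rex import REXScheduler

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

scheduler = REXScheduler(optimizer, total_steps=1000, max_lr=1e-3, min_lr=1e-6)

for _ in range(1000):
    # ... forward / backward ...
    optimizer.step()
    scheduler.step()  # the lr follows the REX curve from max_lr down to min_lr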