Skip to content

Learning Rate Scheduler

deberta_v3_large_lr_scheduler(model, layer_low_threshold=195, layer_middle_threshold=323, head_param_start=390, base_lr=2e-05, head_lr=0.0001, wd=0.01)

DeBERTa-v3 large layer-wise learning rate scheduler.

Reference: https://github.com/gilfernandes/commonlit

Parameters:

Name Type Description Default
model Module

Model based on Huggingface Transformers.

required
layer_low_threshold int

Index where the lower 12 layers start.

195
layer_middle_threshold int

Index where the middle 24 layers end.

323
head_param_start int

Starting index of the head parameters (end of backbone).

390
base_lr float

Base learning rate for backbone layers.

2e-05
head_lr float

Learning rate for head layers.

0.0001
wd float

Weight decay.

0.01
Source code in pytorch_optimizer/lr_scheduler/experimental/deberta_v3_lr_scheduler.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def deberta_v3_large_lr_scheduler(
    model: nn.Module,
    layer_low_threshold: int = 195,
    layer_middle_threshold: int = 323,
    head_param_start: int = 390,
    base_lr: float = 2e-5,
    head_lr: float = 1e-4,
    wd: float = 1e-2,
) -> ParamsT:
    r"""Build layer-wise learning rate parameter groups for DeBERTa-v3 large.

    Parameters are bucketed by their position in ``model.named_parameters()``:

    * index ``< layer_low_threshold``: ``base_lr / 2.5``
    * ``layer_low_threshold <= index < layer_middle_threshold``: ``base_lr``
    * ``layer_middle_threshold <= index < head_param_start``: ``base_lr / 0.5``
    * index ``>= head_param_start``: one shared group trained with ``head_lr``

    Reference: https://github.com/gilfernandes/commonlit

    Args:
        model (nn.Module): Model based on Huggingface Transformers.
        layer_low_threshold (int): First parameter index that receives ``base_lr``.
        layer_middle_threshold (int): First parameter index that receives ``base_lr / 0.5``.
        head_param_start (int): Starting index of the head parameters (end of backbone).
        base_lr (float): Base learning rate for backbone layers.
        head_lr (float): Learning rate for head layers.
        wd (float): Weight decay applied to backbone parameters whose name contains neither ``bias`` nor
            ``LayerNorm.weight``.

    Returns:
        ParamsT: List of parameter-group dicts. The first entry is the head group (``params`` and ``lr`` only); it
            is followed by one group per backbone parameter carrying ``params``, ``weight_decay`` and ``lr``.

    """
    all_named = list(model.named_parameters())

    # All head parameters share a single group; no explicit weight decay is attached to it.
    groups = [{'params': [param for _, param in all_named[head_param_start:]], 'lr': head_lr}]

    for index, (param_name, param) in enumerate(all_named[:head_param_start]):
        if index >= layer_middle_threshold:
            group_lr = base_lr / 0.5  # upper backbone block: twice the base rate
        elif index >= layer_low_threshold:
            group_lr = base_lr  # middle backbone block
        else:
            group_lr = base_lr / 2.5  # lowest backbone block: smallest rate

        # biases and LayerNorm weights are excluded from weight decay
        skip_decay = 'bias' in param_name or 'LayerNorm.weight' in param_name

        groups.append({'params': param, 'weight_decay': 0.0 if skip_decay else wd, 'lr': group_lr})

    return groups

get_chebyshev_schedule(optimizer, num_epochs, is_warmup=False, last_epoch=-1)

Get Chebyshev learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
num_epochs int

Number of total epochs.

required
is_warmup bool

Whether it is the warm-up stage.

False
last_epoch int

The index of the last epoch when resuming training.

-1
Source code in pytorch_optimizer/lr_scheduler/chebyshev.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def get_chebyshev_schedule(
    optimizer: Optimizer, num_epochs: int, is_warmup: bool = False, last_epoch: int = -1
) -> LRScheduler:
    """Create a ``LambdaLR`` scheduler driven by the Chebyshev learning rate lambda.

    Args:
        optimizer (Optimizer): The optimizer for which to schedule the learning rate.
        num_epochs (int): Number of total epochs.
        is_warmup (bool): Whether it is the warm-up stage.
        last_epoch (int): The index of the last epoch when resuming training.

    Returns:
        LRScheduler: ``LambdaLR`` wrapping ``get_chebyshev_lr_lambda`` with the given settings bound.

    """
    # Bind the schedule configuration so LambdaLR only has to supply the epoch index.
    lambda_kwargs = {'num_epochs': num_epochs, 'is_warmup': is_warmup}
    chebyshev_lambda = partial(get_chebyshev_lr_lambda, **lambda_kwargs)

    return LambdaLR(optimizer, chebyshev_lambda, last_epoch)

get_wsd_schedule(optimizer, num_warmup_steps, num_stable_steps, num_decay_steps, min_lr_ratio=0.0, num_cycles=0.5, cooldown_type='1-sqrt', last_epoch=-1)

Get Warmup-Stable-Decay (WSD) learning rate scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer for which to schedule the learning rate.

required
num_warmup_steps int

The number of warmup steps.

required
num_stable_steps int

The number of stable steps.

required
num_decay_steps int

The number of decay steps.

required
min_lr_ratio float

The minimum learning rate as a ratio of the initial learning rate.

0.0
num_cycles float

The number of waves in the cosine schedule (default is a half-cosine decay).

0.5
cooldown_type COOLDOWN_TYPE

Cooldown type of the learning rate scheduler.

'1-sqrt'
last_epoch int

The index of the last epoch when resuming training.

-1
Source code in pytorch_optimizer/lr_scheduler/wsd.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def get_wsd_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_decay_steps: int,
    min_lr_ratio: float = 0.0,
    num_cycles: float = 0.5,
    cooldown_type: COOLDOWN_TYPE = '1-sqrt',
    last_epoch: int = -1,
) -> LRScheduler:
    r"""Create a ``LambdaLR`` scheduler following the Warmup-Stable-Decay (WSD) scheme.

    Args:
        optimizer (Optimizer): The optimizer for which to schedule the learning rate.
        num_warmup_steps (int): The number of warmup steps.
        num_stable_steps (int): The number of stable steps.
        num_decay_steps (int): The number of decay steps.
        min_lr_ratio (float): The minimum learning rate as a ratio of the initial learning rate.
        num_cycles (float): The number of waves in the cosine schedule (default is a half-cosine decay).
        cooldown_type (COOLDOWN_TYPE): Cooldown type of the learning rate scheduler.
        last_epoch (int): The index of the last epoch when resuming training.

    Returns:
        LRScheduler: ``LambdaLR`` wrapping ``get_wsd_scheduler_lambda`` with the given settings bound.

    """
    # Every schedule setting is forwarded under its own name; LambdaLR supplies the step index.
    schedule_config = {
        'num_warmup_steps': num_warmup_steps,
        'num_stable_steps': num_stable_steps,
        'num_decay_steps': num_decay_steps,
        'min_lr_ratio': min_lr_ratio,
        'num_cycles': num_cycles,
        'cooldown_type': cooldown_type,
    }
    wsd_lambda = partial(get_wsd_scheduler_lambda, **schedule_config)

    return LambdaLR(optimizer, wsd_lambda, last_epoch)

CosineAnnealingWarmupRestarts

Bases: LRScheduler

CosineAnnealingWarmupRestarts.

Parameters:

Name Type Description Default
optimizer Optimizer

Wrapped optimizer instance.

required
first_cycle_steps int

Number of steps in the first cycle.

required
cycle_mult float

Cycle steps magnification factor.

1.0
max_lr float

Maximum learning rate.

0.0001
min_lr float

Minimum learning rate.

1e-06
warmup_steps int

Number of warmup steps.

0
gamma float

Decrease rate of max learning rate by cycle.

0.9
last_epoch int

The index of the last epoch for resuming training.

-1
Source code in pytorch_optimizer/lr_scheduler/cosine_anealing.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class CosineAnnealingWarmupRestarts(LRScheduler):
    r"""Cosine annealing learning rate scheduler with linear warmup and restarts.

    Within each cycle the learning rate rises linearly from ``min_lr`` to the current ``max_lr`` over
    ``warmup_steps`` steps, then follows a half cosine from ``max_lr`` back down to ``min_lr`` over the rest of the
    cycle. After every completed cycle the peak is multiplied by ``gamma``.

    Args:
        optimizer (Optimizer): Wrapped optimizer instance.
        first_cycle_steps (int): Number of steps in the first cycle.
        cycle_mult (float): Cycle steps magnification factor.
        max_lr (float): Maximum learning rate.
        min_lr (float): Minimum learning rate.
        warmup_steps (int): Number of warmup steps.
        gamma (float): Decrease rate of max learning rate by cycle.
        last_epoch (int): The index of the last epoch for resuming training.

    Raises:
        ValueError: If ``warmup_steps`` is not smaller than ``first_cycle_steps``.

    """

    def __init__(
        self,
        optimizer: Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 1e-4,
        min_lr: float = 1e-6,
        warmup_steps: int = 0,
        gamma: float = 0.9,
        last_epoch: int = -1,
    ):
        # The cosine phase needs at least one step, so warmup must be strictly shorter than the cycle.
        # NOTE(review): the message prints the offending values in the *required* relation ("a < b"), which reads
        # oddly when the check has just failed.
        if warmup_steps >= first_cycle_steps:
            raise ValueError(
                f'warmup_steps must be smaller than first_cycle_steps. {warmup_steps} < {first_cycle_steps}'
            )

        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        # base_max_lr keeps the undecayed peak; max_lr is the peak of the current cycle (see step()).
        self.base_max_lr = max_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        # length of the cycle currently in progress and the position inside it
        self.cur_cycle_steps = first_cycle_steps
        self.step_in_cycle = last_epoch
        self.last_epoch = last_epoch

        self.cycle: int = 0
        self.base_lrs: List[float] = []

        # NOTE(review): presumably the base-class constructor performs an initial step() — confirm against the
        # installed torch version; all attributes read by step()/get_lr() are therefore set before this call.
        super().__init__(optimizer, last_epoch)

        # Runs after the base constructor, so any lr / base_lrs values set there are overwritten with min_lr.
        self.init_lr()

    def init_lr(self) -> None:
        """Reset every param group's learning rate, and ``base_lrs``, to ``min_lr``."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self) -> List[float]:
        """Return one learning rate per entry of ``base_lrs`` for the current position in the cycle."""
        # no step has been taken yet
        if self.step_in_cycle == -1:
            return self.base_lrs

        # warmup: linear interpolation from base_lr (at step 0) towards max_lr (at warmup_steps)
        if self.step_in_cycle < self.warmup_steps:
            return [
                (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs
            ]

        # annealing: half cosine that equals max_lr at the end of warmup and base_lr at the end of the cycle
        return [
            base_lr
            + (self.max_lr - base_lr)
            * (
                1
                + math.cos(
                    math.pi * (self.step_in_cycle - self.warmup_steps) / (self.cur_cycle_steps - self.warmup_steps)
                )
            )
            / 2.0
            for base_lr in self.base_lrs
        ]

    def step(self, epoch: Optional[int] = None):
        """Advance the schedule and write the new learning rates into the optimizer.

        Args:
            epoch (Optional[int]): When None, move forward by one step from the current state. Otherwise jump
                directly to this global step index and recompute the cycle bookkeeping from it.

        """
        if epoch is None:
            # incremental mode: move one step forward inside the current cycle
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # cycle finished: restart, carrying over any overshoot; only the non-warmup part of the cycle
                # length is scaled by cycle_mult
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
                )
        elif epoch >= self.first_cycle_steps:
            # absolute mode past the first cycle: derive cycle index and offset in closed form
            if self.cycle_mult == 1.0:
                # all cycles have the same length
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            else:
                # Cycle lengths form a geometric series first_cycle_steps * cycle_mult ** k; n is the number of
                # whole cycles that fit before `epoch`.
                # NOTE(review): this closed form scales the whole cycle (warmup included) and leaves
                # cur_cycle_steps uncast, whereas the incremental branch scales only the non-warmup part and
                # casts to int — the two modes can disagree; confirm which is intended.
                n: int = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                self.cycle = n
                self.step_in_cycle = epoch - int(
                    self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)
                )  # fmt: skip
                self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** n  # fmt: skip
        else:
            # absolute mode, still inside the first cycle
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        # the peak decays geometrically with the number of completed cycles
        self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)  # fmt: skip
        self.last_epoch = math.floor(epoch)

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

CosineScheduler

Bases: BaseLinearWarmupScheduler

Cosine LR scheduler with linear warmup.

Source code in pytorch_optimizer/lr_scheduler/linear_warmup.py
17
18
19
20
21
22
class CosineScheduler(BaseLinearWarmupScheduler):
    """Cosine LR scheduler with linear warmup.

    After warmup, the learning rate follows a half cosine from ``max_lr`` towards ``min_lr``.
    """

    def _step(self) -> float:
        # fraction of the post-warmup phase that has elapsed, mapped onto [0, pi]
        progress = (self.step_t - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        phase: float = progress * math.pi

        lr_span = self.max_lr - self.min_lr
        return self.min_lr + lr_span * (np.cos(phase) + 1.0) / 2.0

get_supported_lr_schedulers(filters=None)

Return list of available lr scheduler names, sorted alphabetically.

:param filters: Optional[Union[str, List[str]]]. Wildcard filter string (or list of strings) that works with fnmatch. If None, the whole list is returned.

Source code in pytorch_optimizer/lr_scheduler/__init__.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def get_supported_lr_schedulers(filters: Optional[Union[str, List[str]]] = None) -> List[str]:
    r"""Return list of available lr scheduler names, sorted alphabetically.

    :param filters: Optional[Union[str, List[str]]]. wildcard filter string (or list/tuple of them) that works with
        fnmatch. if None, it will return the whole list.
    """
    if filters is None:
        return sorted(LR_SCHEDULERS.keys())

    # a single pattern is treated as a one-element list
    patterns: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters]

    # a set removes names matched by more than one pattern
    matched: Set[str] = {
        scheduler_name for pattern in patterns for scheduler_name in fnmatch.filter(LR_SCHEDULERS.keys(), pattern)
    }

    return sorted(matched)

LinearScheduler

Bases: BaseLinearWarmupScheduler

Linear LR scheduler with linear warmup.

Source code in pytorch_optimizer/lr_scheduler/linear_warmup.py
 8
 9
10
11
12
13
14
class LinearScheduler(BaseLinearWarmupScheduler):
    """Linear LR scheduler with linear warmup.

    After warmup, the learning rate moves on a straight line from ``max_lr`` towards ``min_lr``.
    """

    def _step(self) -> float:
        lr_delta = self.min_lr - self.max_lr
        steps_since_warmup = self.step_t - self.warmup_steps
        decay_steps = self.total_steps - self.warmup_steps

        return self.max_lr + lr_delta * steps_since_warmup / decay_steps

load_lr_scheduler(lr_scheduler_name)

Load learning rate scheduler.

:param lr_scheduler_name: learning rate scheduler name.

Source code in pytorch_optimizer/lr_scheduler/__init__.py
72
73
74
75
76
77
78
79
80
81
82
def load_lr_scheduler(lr_scheduler_name: str) -> SchedulerClass:
    r"""Look up a learning rate scheduler class by name (case-insensitive).

    :param lr_scheduler_name: learning rate scheduler name.
    :raises NotImplementedError: if the lower-cased name is not a key of ``LR_SCHEDULERS``.
    """
    lrs_name: str = lr_scheduler_name.lower()

    if lrs_name in LR_SCHEDULERS:
        return LR_SCHEDULERS[lrs_name]

    raise NotImplementedError(f'not implemented lr_scheduler {lrs_name}')

PolyScheduler

Bases: BaseLinearWarmupScheduler

Poly LR Scheduler.

Parameters:

Name Type Description Default
poly_order float

lr scheduler decreases with steps.

0.5
Source code in pytorch_optimizer/lr_scheduler/linear_warmup.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class PolyScheduler(BaseLinearWarmupScheduler):
    """Polynomial LR scheduler with linear warmup.

    Args:
        poly_order (float): Exponent applied to the number of steps elapsed since warmup. Must be positive.

    Raises:
        ValueError: If ``poly_order`` is not positive.

    """

    def __init__(self, optimizer, poly_order: float = 0.5, **kwargs):
        self.poly_order = poly_order

        if not poly_order > 0:
            raise ValueError(f'poly_order must be positive. {poly_order}')

        # remaining keyword arguments are handed to the base scheduler unchanged
        super().__init__(optimizer, **kwargs)

    def _step(self) -> float:
        # NOTE(review): for a positive poly_order this term grows with step_t, so the result increases over time
        # whenever max_lr > min_lr — confirm this is the intended shape of the schedule.
        steps_since_warmup = self.step_t - self.warmup_steps
        lr_span = self.max_lr - self.min_lr

        return self.min_lr + lr_span * steps_since_warmup ** self.poly_order

ProportionScheduler

ProportionScheduler (Rho Scheduler of GSAM).

This scheduler outputs a value that evolves proportionally to a given learning rate scheduler.

Parameters:

Name Type Description Default
lr_scheduler LRScheduler

Learning rate scheduler.

required
max_lr float

Maximum learning rate.

required
min_lr float

Minimum learning rate.

0.0
max_value float

Maximum value of rho.

2.0
min_value float

Minimum value of rho.

2.0
Source code in pytorch_optimizer/lr_scheduler/proportion.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class ProportionScheduler:
    """ProportionScheduler (Rho Scheduler of GSAM).

    Maps the current learning rate of a wrapped scheduler linearly from ``[min_lr, max_lr]`` onto
    ``[min_value, max_value]``, so the produced value evolves proportionally to the learning rate.

    Args:
        lr_scheduler (LRScheduler): Learning rate scheduler.
        max_lr (float): Maximum learning rate.
        min_lr (float): Minimum learning rate.
        max_value (float): Maximum value of rho.
        min_value (float): Minimum value of rho.

    """

    def __init__(
        self,
        lr_scheduler: LRScheduler,
        max_lr: float,
        min_lr: float = 0.0,
        max_value: float = 2.0,
        min_value: float = 2.0,
    ):
        self.lr_scheduler = lr_scheduler
        self.max_lr, self.min_lr = max_lr, min_lr
        self.max_value, self.min_value = max_value, min_value

        self.step_t: int = 0
        self.last_lr: List[float] = []

        # take one step immediately so that get_lr() has a value to return
        self.step()

    def get_lr(self) -> float:
        """Return the value produced by the most recent ``step()``."""
        return self.last_lr[0]

    def step(self) -> float:
        """Recompute the value from the wrapped scheduler's current learning rate and return it."""
        self.step_t += 1

        # prefer the scheduler's own `last_lr`; otherwise read the first param group of its optimizer
        source = self.lr_scheduler
        if hasattr(source, 'last_lr'):
            lr = source.last_lr[0]
        else:
            lr = source.optimizer.param_groups[0]['lr']

        # without a usable learning rate range, the value stays at max_value
        value = self.max_value
        if self.max_lr > self.min_lr:
            value_span = self.max_value - self.min_value
            value = self.min_value + value_span * (lr - self.min_lr) / (self.max_lr - self.min_lr)

        self.last_lr = [value]

        return value

REXScheduler

Bases: LRScheduler

Revisiting Budgeted Training with an Improved Schedule.

Parameters:

Name Type Description Default
optimizer Optimizer

Wrapped optimizer instance.

required
total_steps int

Number of steps to optimize.

required
max_lr float

Maximum learning rate.

1.0
min_lr float

Minimum learning rate.

0.0
Source code in pytorch_optimizer/lr_scheduler/rex.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
class REXScheduler(LRScheduler):
    """Revisiting Budgeted Training with an Improved Schedule.

    The learning rate starts at ``max_lr`` and decays to ``min_lr`` over ``total_steps`` following
    ``(1 - p) / (1 - p / 2)``, where ``p`` is the fraction of steps already taken.

    Args:
        optimizer (Optimizer): Wrapped optimizer instance.
        total_steps (int): Number of steps to optimize.
        max_lr (float): Maximum learning rate.
        min_lr (float): Minimum learning rate.

    """

    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        max_lr: float = 1.0,
        min_lr: float = 0.0,
    ):
        self.total_steps = total_steps
        self.max_lr, self.min_lr = max_lr, min_lr

        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # most recently produced learning rate, kept as a one-element list
        self.last_lr: List[float] = [self.max_lr]

        # All state read by step() is assigned above, before the base-class constructor runs.
        super().__init__(optimizer)

        self.init_lr()

    def init_lr(self) -> None:
        """Reset every param group's learning rate, and ``base_lrs``, to ``min_lr``."""
        self.base_lrs = []
        for group in self.optimizer.param_groups:
            group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self) -> float:
        """Return the learning rate produced by the most recent ``step()``."""
        return self.last_lr[0]

    def get_linear_lr(self) -> float:
        """Compute the learning rate for the current ``step_t`` without advancing the schedule."""
        # the budget is used up: stay at the floor
        if self.step_t >= self.total_steps:
            return self.min_lr

        progress: float = self.step_t / self.total_steps
        rex_factor = (1.0 - progress) / (1.0 - progress / 2.0)

        return self.min_lr + (self.max_lr - self.min_lr) * rex_factor

    def step(self, epoch: Optional[int] = None) -> float:
        """Apply the learning rate for the current step to all param groups, then advance by one step.

        Args:
            epoch (Optional[int]): Accepted for signature compatibility; not used.

        Returns:
            float: The learning rate that was applied.

        """
        current_lr: float = self.get_linear_lr()

        self.step_t += 1

        if self.optimizer is not None:
            for group in self.optimizer.param_groups:
                group['lr'] = current_lr

        self.last_lr = [current_lr]

        return current_lr