Skip to content

Loss Function

bi_tempered_logistic_loss(activations, labels, t1, t2, label_smooth=0.0, num_iters=5, reduction='mean')

Bi-Tempered Logistic Loss.

Parameters:

Name Type Description Default
activations Tensor

A multidimensional tensor with last dimension num_classes.

required
labels Tensor

Tensor with the same shape and dtype as activations (one-hot encoded), or a long tensor with one dimension less (class indices).

required
t1 float

Temperature 1 (< 1.0 for boundedness of loss).

required
t2 float

Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).

required
label_smooth float

Label smoothing parameter, between 0 and 1.

0.0
num_iters int

Number of iterations to run the normalization method.

5
reduction str

Specifies reduction method to apply to output: 'none', 'mean', or 'sum'.

'mean'
Source code in pytorch_optimizer/loss/bi_tempered.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def bi_tempered_logistic_loss(
    activations: torch.Tensor,
    labels: torch.Tensor,
    t1: float,
    t2: float,
    label_smooth: float = 0.0,
    num_iters: int = 5,
    reduction: str = 'mean',
) -> torch.Tensor:
    r"""Bi-Tempered Logistic Loss.

    Args:
        activations (torch.Tensor): A multidimensional tensor with last dimension `num_classes`.
        labels (torch.Tensor): Tensor with the same shape and dtype as activations (one-hot encoded),
            or a long tensor with one dimension less (class indices).
        t1 (float): Temperature 1 (< 1.0 for boundedness of loss).
        t2 (float): Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
        label_smooth (float): Label smoothing parameter, between 0 and 1.
        num_iters (int): Number of iterations to run the normalization method.
        reduction (str): Specifies reduction method to apply to output: 'none', 'mean', or 'sum'.

    Returns:
        torch.Tensor: a scalar for 'mean' / 'sum'; otherwise the per-sample losses, i.e. a tensor shaped like
            `activations` without its last (class) dimension.
    """
    if labels.dim() < activations.dim():
        # Class indices -> one-hot along the class axis. The class axis is the LAST one (the index is appended
        # with `labels[..., None]`), so scatter along dim=-1; a hard-coded dim=1 is only correct for 2-D inputs.
        labels_onehot = torch.zeros_like(activations)
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smooth > 0:
        # Maps 1 -> 1 - label_smooth and 0 -> label_smooth / (num_classes - 1), so each row still sums to 1.
        num_classes: int = labels_onehot.shape[-1]
        labels_onehot = (1.0 - label_smooth * num_classes / (num_classes - 1)) * labels_onehot + label_smooth / (
            num_classes - 1
        )

    probabilities = tempered_softmax(activations, t2, num_iters)

    loss_values = (
        labels_onehot * log_t(labels_onehot + 1e-10, t1)  # 1e-10 keeps log_t away from exactly 0
        - labels_onehot * log_t(probabilities, t1)
        - labels_onehot.pow(2.0 - t1) / (2.0 - t1)
        + probabilities.pow(2.0 - t1) / (2.0 - t1)
    )
    loss_values = loss_values.sum(dim=-1)

    if reduction == 'sum':
        return loss_values.sum()
    if reduction == 'mean':
        return loss_values.mean()
    return loss_values

BCEFocalLoss

Bases: Module

BCEFocal loss function with probability input.

Parameters:

Name Type Description Default
alpha float

Weighting factor for class imbalance, commonly set to 0.25.

0.25
gamma float

Focusing parameter to reduce loss contribution of easy examples.

2.0
label_smooth float

Smoothness constant to regularize target labels.

0.0
eps float

Small epsilon to avoid numerical instability.

1e-06
reduction str

Specifies reduction type to apply to output: 'none', 'mean' or 'sum'.

'mean'
Source code in pytorch_optimizer/loss/focal.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class BCEFocalLoss(nn.Module):
    """BCEFocal loss function with probability input.

    The loss is `alpha * (1 - p)^gamma * BCE` for positives and `p^gamma * BCE` for negatives, where `p` is the
    predicted probability, so confidently-correct (easy) examples contribute less.

    Args:
        alpha (float): Weighting factor for class imbalance (applied to the positive term), commonly set to 0.25.
        gamma (float): Focusing parameter to reduce loss contribution of easy examples.
        label_smooth (float): Smoothness constant to regularize target labels (applied inside the BCE term).
        eps (float): Small epsilon to avoid numerical instability.
        reduction (str): Specifies reduction type to apply to output: 'none', 'mean' or 'sum'.

    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        label_smooth: float = 0.0,
        eps: float = 1e-6,
        reduction: str = 'mean',
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

        # per-element BCE; the reduction is applied after focal weighting
        self.bce = BCELoss(label_smooth=label_smooth, eps=eps, reduction='none')

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        bce_loss = self.bce(y_pred, y_true)

        # Positives are modulated by (1 - p)^gamma and negatives by p^gamma. The negative term previously used
        # (1 - y_true)^gamma, which for hard labels is just (1 - y_true) and removes the focusing effect.
        focal_loss = (
            y_true * self.alpha * (1.0 - y_pred) ** self.gamma * bce_loss
            + (1.0 - y_true) * y_pred ** self.gamma * bce_loss
        )  # fmt: skip

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

BCELoss

Bases: Module

Binary Cross Entropy loss with label smoothing and probability input.

Parameters:

Name Type Description Default
label_smooth float

Smoothness constant to soften target labels.

0.0
eps float

Small epsilon to avoid numerical instability.

1e-06
reduction str

Specifies the reduction to apply to the output; 'none' | 'mean' | 'sum'.

'mean'
Source code in pytorch_optimizer/loss/cross_entropy.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class BCELoss(nn.Module):
    """Binary Cross Entropy loss with label smoothing and probability input.

    Args:
        label_smooth (float): Smoothness constant to soften target labels. In training mode targets become
            `(1 - label_smooth) * y_true + label_smooth / 2`, i.e. they are pulled toward 0.5.
        eps (float): Small epsilon to avoid numerical instability; predictions are clamped to `[eps, 1 - eps]`.
        reduction (str): Specifies the reduction to apply to the output; 'none' | 'mean' | 'sum'.

    """

    def __init__(self, label_smooth: float = 0.0, eps: float = 1e-6, reduction: str = 'mean'):
        super().__init__()
        self.label_smooth = label_smooth
        self.eps = eps
        self.reduction = reduction

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.training and self.label_smooth > 0.0:
            # Each element is an independent two-class problem, so the uniform share is label_smooth / 2.
            # It must not depend on the tensor layout (dividing by y_pred.size(-1) tied it to the batch size
            # for 1-D input and made (B, 1) targets asymmetric).
            y_true = (1.0 - self.label_smooth) * y_true + 0.5 * self.label_smooth
        y_pred = torch.clamp(y_pred, self.eps, 1.0 - self.eps)
        return binary_cross_entropy(y_pred, y_true, reduction=self.reduction)

BinaryBiTemperedLogisticLoss

Bases: Module

Bi-Tempered Logistic Loss for Binary Classification.

Parameters:

Name Type Description Default
t1 float

Temperature 1 (< 1.0 for boundedness of the loss).

required
t2 float

Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).

required
label_smooth float

Label smoothing parameter between 0 and 1.

0.0
ignore_index Optional[int]

Specifies a target value that is ignored and does not contribute to the input gradient.

None
reduction str

Specifies the reduction to apply to the output: 'none', 'mean', or 'sum'.

'mean'
Source code in pytorch_optimizer/loss/bi_tempered.py
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
class BinaryBiTemperedLogisticLoss(nn.Module):
    """Bi-Tempered Logistic Loss for Binary Classification.

    The single-channel logit `x` is expanded into the two-class logits `[-x, x]` with targets `[1 - y, y]`,
    and the multi-class bi-tempered loss is evaluated on that pair.

    Args:
        t1 (float): Temperature 1 (< 1.0 for boundedness of the loss).
        t2 (float): Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
        label_smooth (float): Label smoothing parameter between 0 and 1.
        ignore_index (Optional[int]): Target value whose positions get zero loss (they still count in the
            'mean' denominator).
        reduction (str): Specifies the reduction to apply to the output: 'none', 'mean', or 'sum'.

    Raises:
        ValueError: if `predictions` or `targets` does not have exactly one channel in dimension 1.
    """

    def __init__(
        self,
        t1: float,
        t2: float,
        label_smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        reduction: str = 'mean',
    ):
        super().__init__()
        self.t1, self.t2 = t1, t2
        self.label_smooth = label_smooth
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        if not (predictions.size(1) == 1 and targets.size(1) == 1):
            raise ValueError('Channel dimension for predictions and targets must be equal to 1')

        # Build the symmetric two-class problem, then move the class axis last as the functional loss expects.
        pair_logits = torch.cat((-predictions, predictions), dim=1).moveaxis(1, -1)
        pair_targets = torch.cat((1.0 - targets, targets), dim=1).moveaxis(1, -1)

        per_position = bi_tempered_logistic_loss(
            pair_logits,
            pair_targets,
            t1=self.t1,
            t2=self.t2,
            label_smooth=self.label_smooth,
            reduction='none',
        )
        # restore the singleton channel so the loss lines up with `targets`
        loss = per_position.unsqueeze(dim=1)

        if self.ignore_index is not None:
            loss = torch.masked_fill(loss, targets.eq(self.ignore_index), value=0)

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

BiTemperedLogisticLoss

Bases: Module

Bi-Tempered Log Loss.

Reference

https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/bitempered_loss.py

Parameters:

Name Type Description Default
t1 float

Temperature 1 (< 1.0 for boundedness).

required
t2 float

Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).

required
label_smooth float

Label smoothing parameter between 0 and 1.

0.0
ignore_index Optional[int]

Index to ignore during loss calculation.

None
reduction str

Type of reduction to apply to output, e.g. 'mean', 'sum', or 'none'.

'mean'
Source code in pytorch_optimizer/loss/bi_tempered.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
class BiTemperedLogisticLoss(nn.Module):
    """Bi-Tempered Log Loss.

    Reference:
        https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/bitempered_loss.py

    Args:
        t1 (float): Temperature 1 (< 1.0 for boundedness).
        t2 (float): Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
        label_smooth (float): Label smoothing parameter between 0 and 1.
        ignore_index (Optional[int]): Target class index to ignore during loss calculation. Ignored positions get
            zero loss (they still count in the 'mean' denominator). Intended for class-index targets.
        reduction (str): Type of reduction to apply to output, e.g. 'mean', 'sum', or 'none'.

    """

    def __init__(
        self,
        t1: float,
        t2: float,
        label_smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        reduction: str = 'mean',
    ):
        super().__init__()
        self.t1 = t1
        self.t2 = t2
        self.label_smooth = label_smooth
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        valid_mask: Optional[torch.Tensor] = None
        if self.ignore_index is not None:
            valid_mask = targets.ne(self.ignore_index)
            # Ignored entries (e.g. -100) are not valid class indices and would make the one-hot scatter inside
            # `bi_tempered_logistic_loss` fail. Swap in class 0 as a placeholder; its loss is zeroed below.
            targets = targets.masked_fill(~valid_mask, 0)

        loss = bi_tempered_logistic_loss(
            predictions, targets, t1=self.t1, t2=self.t2, label_smooth=self.label_smooth, reduction='none'
        )

        if valid_mask is not None:
            # masked_fill (unlike multiplying by the mask) also clears any non-finite values at ignored positions
            loss = loss.masked_fill(~valid_mask, 0.0)

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

DiceLoss

Bases: _Loss

Dice loss for image segmentation task.

Reference

https://github.com/BloodAxe/pytorch-toolbelt

Parameters:

Name Type Description Default
mode ClassMode

Loss mode - 'binary', 'multiclass', or 'multilabel'.

'binary'
classes Optional[List[int]]

List of classes to include in loss computation. Defaults to all classes.

None
log_loss bool

If True, loss is computed as -log(dice_coeff); otherwise 1 - dice_coeff.

False
from_logits bool

If True, assumes input is raw logits.

True
label_smooth float

Smoothness constant for dice coefficient numerator and denominator.

0.0
ignore_index Optional[int]

Label to ignore during loss computation.

None
eps float

Small epsilon for numerical stability.

1e-06
Source code in pytorch_optimizer/loss/dice.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class DiceLoss(_Loss):
    """Dice loss for image segmentation task.

    Reference:
        https://github.com/BloodAxe/pytorch-toolbelt

    Args:
        mode (ClassMode): Loss mode - 'binary', 'multiclass', or 'multilabel'.
        classes (Optional[List[int]]): Indices of the per-class losses to keep before averaging. Defaults to all
            classes. Not allowed with `mode='binary'`.
        log_loss (bool): If True, loss is computed as `-log(dice_coeff)` (coefficient clamped at `eps`);
            otherwise `1 - dice_coeff`.
        from_logits (bool): If True, `y_pred` is treated as raw logits and turned into probabilities
            (softmax over dim 1 for 'multiclass', sigmoid otherwise).
        label_smooth (float): Smoothness constant added to the dice coefficient numerator and denominator.
        ignore_index (Optional[int]): Target value whose positions are excluded from the loss computation.
        eps (float): Small epsilon for numerical stability.

    Raises:
        ValueError: if `classes` is given together with `mode='binary'`.
    """

    def __init__(
        self,
        mode: ClassMode = 'binary',
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        label_smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-6,
    ):
        super().__init__()

        if classes is not None and mode == 'binary':
            raise ValueError('masking classes is not supported with mode=binary')

        self.mode = mode
        self.classes = classes
        self.from_logits = from_logits
        self.label_smooth = label_smooth
        self.eps = eps
        self.log_loss = log_loss
        self.ignore_index = ignore_index

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.from_logits:
            # Activate in log-space and exponentiate: numerically stabler and avoids vanishing gradients
            # at the extreme probabilities 0 and 1.
            if self.mode == 'multiclass':
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = logsigmoid(y_pred).exp()

        batch_size: int = y_true.size(0)
        n_classes: int = y_pred.size(1)

        # Tensors are reshaped to (batch, channels, pixels); reducing over (0, 2) gives one score per channel.
        reduce_dims: Tuple[int, ...] = (0, 2)

        if self.mode == 'binary':
            y_true = y_true.view(batch_size, 1, -1)
            y_pred = y_pred.view(batch_size, 1, -1)

            if self.ignore_index is not None:
                keep = y_true != self.ignore_index
                y_pred = y_pred * keep
                y_true = y_true * keep
        elif self.mode == 'multiclass':
            y_true = y_true.view(batch_size, -1)
            y_pred = y_pred.view(batch_size, n_classes, -1)

            if self.ignore_index is None:
                y_true = one_hot(y_true, n_classes).permute(0, 2, 1)
            else:
                keep = y_true != self.ignore_index
                y_pred = y_pred * keep.unsqueeze(1)
                # ignored pixels are mapped to class 0 for one_hot, then wiped out by the mask
                y_true = one_hot((y_true * keep).to(torch.long), n_classes).permute(0, 2, 1) * keep.unsqueeze(1)
        elif self.mode == 'multilabel':
            y_true = y_true.view(batch_size, n_classes, -1)
            y_pred = y_pred.view(batch_size, n_classes, -1)

            if self.ignore_index is not None:
                keep = y_true != self.ignore_index
                y_pred = y_pred * keep
                y_true = y_true * keep

        scores = self.compute_score(
            y_pred, y_true.type_as(y_pred), label_smooth=self.label_smooth, eps=self.eps, dims=reduce_dims
        )

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Dice is only defined for classes that are present in the target, so channels without any true pixel
        # contribute zero. NOTE: using `mean(y_pred)` for such channels would be a better workaround, but it
        # would turn this into a modified loss.
        has_true_pixels = y_true.sum(reduce_dims) > 0
        loss = loss * has_true_pixels.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return self.aggregate_loss(loss)

    @staticmethod
    def aggregate_loss(loss: torch.Tensor) -> torch.Tensor:
        """Average the per-class losses into a scalar."""
        return torch.mean(loss)

    @staticmethod
    def compute_score(
        output: torch.Tensor,
        target: torch.Tensor,
        label_smooth: float = 0.0,
        eps: float = 1e-6,
        dims: Optional[Tuple[int, ...]] = None,
    ) -> torch.Tensor:
        """Score hook used by `forward`; delegates to `soft_dice_score`."""
        return soft_dice_score(output, target, label_smooth=label_smooth, eps=eps, dims=dims)

FocalCosineLoss

Bases: Module

Focal Cosine Loss function with logits input.

Parameters:

Name Type Description Default
alpha float

Weighting factor for class imbalance.

1.0
gamma float

Focusing parameter to reduce loss contribution from easy examples.

2.0
focal_weight float

Weight of the focal loss component in the combined loss.

0.1
reduction str

Specifies the reduction to apply to the output: 'none', 'mean', or 'sum'.

'mean'
Source code in pytorch_optimizer/loss/focal.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class FocalCosineLoss(nn.Module):
    """Focal Cosine Loss function with logits input.

    Combines a cosine-embedding loss between the logits and the one-hot targets with a focal cross-entropy term
    computed on L2-normalized logits: `cosine + focal_weight * focal`.

    Args:
        alpha (float): Weighting factor for class imbalance.
        gamma (float): Focusing parameter to reduce loss contribution from easy examples.
        focal_weight (float): Weight of the focal loss component in the combined loss.
        reduction (str): Specifies the reduction to apply to the output: 'none', 'mean', or 'sum'. Applied to both
            the cosine and the focal terms.

    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, focal_weight: float = 0.1, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.focal_weight = focal_weight
        self.reduction = reduction

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # assumes y_pred is (batch, num_classes) logits and y_true holds integer class indices
        cosine_loss = cosine_embedding_loss(
            y_pred,
            one_hot(y_true, num_classes=y_pred.size(-1)),
            torch.tensor([1], device=y_true.device),
            reduction=self.reduction,
        )

        ce_loss = cross_entropy(normalize(y_pred), y_true, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1.0 - pt) ** self.gamma * ce_loss

        # Reduce the focal term exactly like the cosine term so 'sum' and 'none' combine matching quantities
        # (previously the focal term was always averaged).
        if self.reduction == 'mean':
            focal_loss = focal_loss.mean()
        elif self.reduction == 'sum':
            focal_loss = focal_loss.sum()

        return cosine_loss + self.focal_weight * focal_loss

FocalLoss

Bases: Module

Focal Loss function with logits input.

Parameters:

Name Type Description Default
alpha float

Weighting factor for class imbalance.

1.0
gamma float

Focusing parameter to down-weight easy examples and focus training on hard negatives.

2.0
Source code in pytorch_optimizer/loss/focal.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class FocalLoss(nn.Module):
    """Focal Loss function with logits input.

    Args:
        alpha (float): Weighting factor for class imbalance.
        gamma (float): Focusing parameter to down-weight easy examples and focus training on hard negatives.
        reduction (str): Reduction applied to the element-wise loss: 'mean' (default), 'sum', or 'none'.

    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        bce_loss = binary_cross_entropy_with_logits(y_pred, y_true, reduction='none')
        # exp(-BCE) recovers the probability assigned to the true label (exact for hard 0/1 targets)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1.0 - pt) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

FocalTverskyLoss

Bases: Module

Focal Tversky Loss with logits input.

Parameters:

Name Type Description Default
alpha float

Weight for false positives in the Tversky index.

0.5
beta float

Weight for false negatives in the Tversky index.

0.5
gamma float

Focusing parameter that shapes the loss to focus more on hard examples.

1.0
smooth float

Smoothing factor to avoid division by zero.

1e-06
Source code in pytorch_optimizer/loss/focal.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class FocalTverskyLoss(nn.Module):
    """Focal Tversky Loss with logits input.

    Returns the Tversky loss (`1 - Tversky index`) raised to the power `gamma`.

    Args:
        alpha (float): Weight of false positives in the Tversky index (passed to `TverskyLoss` as `alpha`).
        beta (float): Weight of false negatives in the Tversky index (passed to `TverskyLoss` as `beta`).
        gamma (float): Exponent applied to the Tversky loss; shapes how strongly hard examples are emphasized.
        smooth (float): Smoothing factor to avoid division by zero.

    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, smooth: float = 1e-6):
        super().__init__()
        self.gamma = gamma

        self.tversky = TverskyLoss(alpha, beta, smooth)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        base_loss = self.tversky(y_pred, y_true)
        return base_loss ** self.gamma

get_supported_loss_functions(filters=None)

Return list of available loss function names, sorted alphabetically.

:param filters: Optional[Union[str, List[str]]]. Wildcard filter string (or list of strings) that works with fnmatch. If None, it will return the whole list.

Source code in pytorch_optimizer/loss/__init__.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def get_supported_loss_functions(filters: Optional[Union[str, List[str]]] = None) -> List[str]:
    r"""Return list of available loss function names, sorted alphabetically.

    :param filters: Optional[Union[str, List[str]]]. wildcard pattern (or a list/tuple of patterns) understood by
        `fnmatch`; a name is returned if it matches any pattern. If None, it will return the whole list.
    """
    if filters is None:
        return sorted(LOSS_FUNCTIONS.keys())

    patterns: Sequence[str] = filters if isinstance(filters, (tuple, list)) else [filters]

    matched: Set[str] = {
        name for pattern in patterns for name in fnmatch.filter(LOSS_FUNCTIONS.keys(), pattern)
    }
    return sorted(matched)

JaccardLoss

Bases: _Loss

Jaccard loss for image segmentation.

Reference: https://github.com/BloodAxe/pytorch-toolbelt

Parameters:

Name Type Description Default
mode str

Loss mode, one of 'binary', 'multiclass', or 'multilabel'.

required
classes Optional[List[int]]

List of classes to include in the loss computation, defaults to all classes if None.

None
log_loss bool

If True, loss is computed as -log(jaccard); otherwise, 1 - jaccard.

False
from_logits bool

If True, input is raw logits, which will be converted to probabilities.

True
label_smooth float

Label smoothing constant.

0.0
eps float

Small number to prevent division by zero.

1e-06
Source code in pytorch_optimizer/loss/jaccard.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class JaccardLoss(_Loss):
    r"""Jaccard loss for image segmentation.

    Reference: https://github.com/BloodAxe/pytorch-toolbelt

    Args:
        mode (str): Loss mode, one of 'binary', 'multiclass', or 'multilabel'.
        classes (Optional[List[int]]): List of classes to include in the loss computation,
            defaults to all classes if None. Not allowed with `mode='binary'`.
        log_loss (bool): If True, loss is computed as -log(jaccard) (clamped at `eps`);
            otherwise, 1 - jaccard.
        from_logits (bool): If True, input is raw logits, which will be converted to probabilities
            (softmax over dim 1 for 'multiclass', sigmoid otherwise).
        label_smooth (float): Label smoothing constant.
        eps (float): Small number to prevent division by zero.
        ignore_index (Optional[int]): Target value whose positions are excluded from the loss computation
            (same semantics as in `DiceLoss`). Defaults to None, i.e. nothing is ignored.

    Raises:
        ValueError: if `classes` is given together with `mode='binary'`.
    """

    def __init__(
        self,
        mode: ClassMode,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        label_smooth: float = 0.0,
        eps: float = 1e-6,
        ignore_index: Optional[int] = None,
    ):
        super().__init__()

        if classes is not None and mode == 'binary':
            raise ValueError('masking classes is not supported with mode=binary')

        self.mode = mode
        self.classes = classes
        self.log_loss = log_loss
        self.from_logits = from_logits
        self.label_smooth = label_smooth
        self.eps = eps
        self.ignore_index = ignore_index

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            y_pred = y_pred.log_softmax(dim=1).exp() if self.mode == 'multiclass' else logsigmoid(y_pred).exp()

        bs: int = y_true.size(0)
        num_classes: int = y_pred.size(1)

        # tensors become (batch, channels, pixels); reducing over (0, 2) yields one score per channel
        dims: Tuple[int, ...] = (0, 2)

        if self.mode == 'binary':
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == 'multiclass':
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)

                # ignored pixels are mapped to class 0 for one_hot, then zeroed by the mask
                y_true = one_hot((y_true * mask).to(torch.long), num_classes)
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)
            else:
                y_true = one_hot(y_true, num_classes)
                y_true = y_true.permute(0, 2, 1)

        if self.mode == 'multilabel':
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = soft_jaccard_score(
            y_pred, y_true.type(y_pred.dtype), label_smooth=self.label_smooth, eps=self.eps, dims=dims
        )

        loss = -torch.log(scores.clamp_min(self.eps)) if self.log_loss else 1.0 - scores

        # IoU loss is only defined for non-empty classes,
        # so we zero the contribution of channels that do not have true pixels.
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss

        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return loss.mean()

LDAMLoss

Bases: Module

Label-Distribution-Aware Margin (LDAM) Loss.

Parameters:

Name Type Description Default
num_class_list List[int]

List of the number of samples per class.

required
max_m float

Maximum margin (the C term in the paper).

0.5
weight Optional[Tensor]

Optional class weights for re-weighting.

None
s float

Scaling factor for logits.

30.0
Source code in pytorch_optimizer/loss/ldam.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class LDAMLoss(nn.Module):
    r"""Label-Distribution-Aware Margin (LDAM) Loss.

    Class `j` gets the margin `n_j ** (-1/4)` (with `n_j` its sample count), rescaled so that the largest margin
    equals `max_m`. The target-class logit of each sample is reduced by its margin before a scaled cross-entropy.

    Args:
        num_class_list (List[int]): List of the number of samples per class.
        max_m (float): Maximum margin (the `C` term in the paper).
        weight (Optional[torch.Tensor]): Optional class weights for re-weighting.
        s (float): Scaling factor for logits.

    """

    def __init__(
        self, num_class_list: List[int], max_m: float = 0.5, weight: Optional[torch.Tensor] = None, s: float = 30.0
    ):
        super().__init__()

        cls_num_list: torch.Tensor = torch.tensor(num_class_list, dtype=torch.float32)
        # n ** (-1/4): rarer classes receive larger margins
        m_list: torch.Tensor = 1.0 / cls_num_list.sqrt().sqrt()
        # tensor reduction instead of Python's builtin max(), which iterates the tensor element by element
        m_list = m_list * (max_m / m_list.max())

        self.m_list = m_list.unsqueeze(0)  # (1, num_classes)
        self.weight = weight
        self.s = s

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # boolean mask marking the target class of every sample (same shape as y_pred)
        index = torch.zeros_like(y_pred, dtype=torch.bool)
        index.scatter_(1, y_true.view(-1, 1), 1)

        # margin of each sample's target class: (1, C) @ (C, B) -> (1, B) -> (B, 1)
        batch_m = torch.matmul(self.m_list.to(index.device), index.float().transpose(0, 1)).view(-1, 1)
        x_m = y_pred - batch_m

        output = torch.where(index, x_m, y_pred)
        return cross_entropy(self.s * output, y_true, weight=self.weight)

LovaszHingeLoss

Bases: Module

Binary Lovasz hinge loss.

Parameters:

Name Type Description Default
per_image bool

Compute the loss per image instead of per batch.

True
Source code in pytorch_optimizer/loss/lovasz.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class LovaszHingeLoss(nn.Module):
    r"""Binary Lovasz hinge loss.

    Args:
        per_image (bool): if True, evaluate the loss separately for each item along the first dimension and
            average the results; otherwise evaluate it once on the whole batch.

    """

    def __init__(self, per_image: bool = True):
        super().__init__()
        self.per_image = per_image

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.per_image:
            image_losses = [lovasz_hinge_flat(logits, labels) for logits, labels in zip(y_pred, y_true)]
            return torch.stack(image_losses).mean()

        return lovasz_hinge_flat(y_pred, y_true)

soft_dice_score(output, target, label_smooth=0.0, eps=1e-06, dims=None)

Get soft dice score.

Parameters:

Name Type Description Default
output Tensor

Predicted segmentation probabilities.

required
target Tensor

Ground truth segmentation masks.

required
label_smooth float

Label smoothing factor to avoid zero denominators.

0.0
eps float

Small epsilon for numerical stability.

1e-06
dims Optional[Tuple[int, ...]]

Dimensions over which to reduce when computing score.

None
Source code in pytorch_optimizer/loss/dice.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def soft_dice_score(
    output: torch.Tensor,
    target: torch.Tensor,
    label_smooth: float = 0.0,
    eps: float = 1e-6,
    dims: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    """Get soft dice score.

    Computes `(2 * sum(output * target) + label_smooth) / max(sum(output + target) + label_smooth, eps)`.

    Args:
        output (torch.Tensor): Predicted segmentation probabilities.
        target (torch.Tensor): Ground truth segmentation masks.
        label_smooth (float): Smoothing constant added to numerator and denominator.
        eps (float): Lower bound for the denominator, for numerical stability.
        dims (Optional[Tuple[int, ...]]): Dimensions over which to reduce; None reduces over all elements.

    Returns:
        torch.Tensor: the dice score with `dims` reduced away (a scalar when `dims` is None).
    """
    reduce_kwargs = {} if dims is None else {'dim': dims}

    overlap = (output * target).sum(**reduce_kwargs)
    total = (output + target).sum(**reduce_kwargs)

    numerator = 2.0 * overlap + label_smooth
    denominator = (total + label_smooth).clamp_min(eps)
    return numerator / denominator

soft_jaccard_score(output, target, label_smooth=0.0, eps=1e-06, dims=None)

Get soft Jaccard score.

Parameters:

Name Type Description Default
output Tensor

Predicted segments (probabilities or logits).

required
target Tensor

Ground truth segments.

required
label_smooth float

Label smoothing factor to avoid zero denominators.

0.0
eps float

Small epsilon for numerical stability.

1e-06
dims Optional[Tuple[int, ...]]

Dimensions to reduce over when computing the score.

None
Source code in pytorch_optimizer/loss/jaccard.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def soft_jaccard_score(
    output: torch.Tensor,
    target: torch.Tensor,
    label_smooth: float = 0.0,
    eps: float = 1e-6,
    dims: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    r"""Get soft Jaccard score.

    Computes `(I + label_smooth) / max(S - I + label_smooth, eps)` where `I = sum(output * target)` and
    `S = sum(output + target)`, so `S - I` is the soft union.

    Args:
        output (torch.Tensor): Predicted segments (typically probabilities).
        target (torch.Tensor): Ground truth segments.
        label_smooth (float): Smoothing constant added to numerator and denominator.
        eps (float): Lower bound for the denominator, for numerical stability.
        dims (Optional[Tuple[int, ...]]): Dimensions to reduce over; None reduces over all elements.

    Returns:
        torch.Tensor: the Jaccard score with `dims` reduced away (a scalar when `dims` is None).
    """
    reduce_kwargs = {} if dims is None else {'dim': dims}

    overlap = (output * target).sum(**reduce_kwargs)
    total = (output + target).sum(**reduce_kwargs)
    union = total - overlap

    return (overlap + label_smooth) / (union + label_smooth).clamp_min(eps)

SoftF1Loss

Bases: Module

Soft-F1 loss.

Parameters:

Name Type Description Default
beta float

The beta parameter in the F-beta score, balancing precision vs recall.

1.0
eps float

Small epsilon value to avoid division by zero during calculation.

1e-06
Source code in pytorch_optimizer/loss/f1.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class SoftF1Loss(nn.Module):
    """Soft-F1 loss.

    Computes `1 - F_beta` from soft (probability-weighted) true-positive, false-positive and false-negative
    counts summed over all elements.

    Args:
        beta (float): The beta parameter in the F-beta score, balancing precision vs recall.
        eps (float): Small epsilon value to avoid division by zero during calculation.

    """

    def __init__(self, beta: float = 1.0, eps: float = 1e-6):
        super().__init__()
        self.beta = beta
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        tp = (y_true * y_pred).sum().float()
        # predicted positive but actually negative
        fp = ((1 - y_true) * y_pred).sum().float()
        # actually positive but predicted negative
        fn = (y_true * (1 - y_pred)).sum().float()

        precision = tp / (tp + fp + self.eps)
        recall = tp / (tp + fn + self.eps)

        beta_sq = self.beta ** 2  # fmt: skip
        f1 = (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall + self.eps)
        f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)

        return 1.0 - f1.mean()

TverskyLoss

Bases: Module

Tversky Loss with logits input.

Parameters:

Name Type Description Default
alpha float

Weight of false positives.

0.5
beta float

Weight of false negatives.

0.5
smooth float

Small constant to avoid division by zero.

1e-06
Source code in pytorch_optimizer/loss/tversky.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class TverskyLoss(nn.Module):
    """Tversky Loss with logits input.

    Returns `1 - TI` with `TI = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)`, computed over all elements
    after applying a sigmoid to `y_pred`.

    Args:
        alpha (float): Weight of false positives.
        beta (float): Weight of false negatives.
        smooth (float): Small constant to avoid division by zero.

    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        y_pred = torch.sigmoid(y_pred)

        # reshape (not view) so that non-contiguous inputs such as transposed or sliced tensors are accepted
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1)

        tp = (y_pred * y_true).sum()
        fp = ((1.0 - y_true) * y_pred).sum()
        fn = (y_true * (1.0 - y_pred)).sum()

        tversky_index = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)

        return 1.0 - tversky_index