Loss Function

bi_tempered_logistic_loss(activations, labels, t1, t2, label_smooth=0.0, num_iters=5, reduction='mean')

Bi-Tempered Logistic Loss.

Parameters:

activations (Tensor, required): A multidimensional tensor whose last dimension is num_classes.
labels (Tensor, required): A tensor with the same shape and dtype as activations (one-hot), or a long tensor with one fewer dimension than activations (PyTorch standard).
t1 (float, required): Temperature 1 (< 1.0 for boundedness).
t2 (float, required): Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
label_smooth (float, default 0.0): Label smoothing parameter in [0, 1).
num_iters (int, default 5): Number of iterations to run the method.
reduction (str, default 'mean'): Type of reduction ('mean', 'sum', or 'none').
Source code in pytorch_optimizer/loss/bi_tempered.py
def bi_tempered_logistic_loss(
    activations: torch.Tensor,
    labels: torch.Tensor,
    t1: float,
    t2: float,
    label_smooth: float = 0.0,
    num_iters: int = 5,
    reduction: str = 'mean',
) -> torch.Tensor:
    """Bi-Tempered Logistic Loss.

    :param activations: torch.Tensor. A multidimensional tensor with last dimension `num_classes`.
    :param labels: torch.Tensor. A tensor with shape and dtype as activations (onehot), or a long tensor of
        one dimension less than activations (pytorch standard)
    :param t1: float. Temperature 1 (< 1.0 for boundedness).
    :param t2: float. Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    :param label_smooth: float. Label smoothing parameter between [0, 1).
    :param num_iters: int. Number of iterations to run the method.
    :param reduction: str. type of reduction.
    """
    if len(labels.shape) < len(activations.shape):
        labels_onehot = torch.zeros_like(activations)
        labels_onehot.scatter_(1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smooth > 0:
        num_classes: int = labels_onehot.shape[-1]
        labels_onehot = (1.0 - label_smooth * num_classes / (num_classes - 1)) * labels_onehot + label_smooth / (
            num_classes - 1
        )

    probabilities = tempered_softmax(activations, t2, num_iters)

    loss_values = (
        labels_onehot * log_t(labels_onehot + 1e-10, t1)
        - labels_onehot * log_t(probabilities, t1)
        - labels_onehot.pow(2.0 - t1) / (2.0 - t1)
        + probabilities.pow(2.0 - t1) / (2.0 - t1)
    )
    loss_values = loss_values.sum(dim=-1)

    if reduction == 'sum':
        return loss_values.sum()
    if reduction == 'mean':
        return loss_values.mean()
    return loss_values
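
A minimal usage sketch (the import assumes the module path shown above; the shapes and temperature values are illustrative):

import torch
from pytorch_optimizer.loss.bi_tempered import bi_tempered_logistic_loss

logits = torch.randn(8, 10, requires_grad=True)  # (batch, num_classes) activations
labels = torch.randint(0, 10, (8,))              # class indices, one dimension less than logits
loss = bi_tempered_logistic_loss(logits, labels, t1=0.8, t2=1.2)  # scalar, since reduction defaults to 'mean'
loss.backward()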

BiTemperedLogisticLoss

Bases: Module

Bi-Tempered Log Loss.

Reference : https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/bitempered_loss.py

Parameters:

t1 (float, required): Temperature 1 (< 1.0 for boundedness).
t2 (float, required): Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
label_smooth (float, default 0.0): Label smoothing parameter in [0, 1).
ignore_index (Optional[int], default None): Index to ignore.
reduction (str, default 'mean'): Type of reduction.
Source code in pytorch_optimizer/loss/bi_tempered.py
class BiTemperedLogisticLoss(nn.Module):
    """Bi-Tempered Log Loss.

    Reference : https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/bitempered_loss.py

    :param t1: float. Temperature 1 (< 1.0 for boundedness).
    :param t2: float. Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    :param label_smooth: float. Label smoothing parameter between [0, 1).
    :param ignore_index: Optional[int]. Index to ignore.
    :param reduction: str. type of reduction.
    """

    def __init__(
        self,
        t1: float,
        t2: float,
        label_smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        reduction: str = 'mean',
    ):
        super().__init__()
        self.t1 = t1
        self.t2 = t2
        self.label_smooth = label_smooth
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        loss = bi_tempered_logistic_loss(
            predictions, targets, t1=self.t1, t2=self.t2, label_smooth=self.label_smooth, reduction='none'
        )

        if self.ignore_index is not None:
            mask = ~targets.eq(self.ignore_index)
            loss *= mask

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
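
A minimal usage sketch (import path per the module shown above; shapes, temperatures, and smoothing value are illustrative):

import torch
from pytorch_optimizer.loss.bi_tempered import BiTemperedLogisticLoss

criterion = BiTemperedLogisticLoss(t1=0.8, t2=1.2, label_smooth=0.1)
logits = torch.randn(8, 10, requires_grad=True)  # (batch, num_classes)
targets = torch.randint(0, 10, (8,))             # class indices
loss = criterion(logits, targets)
loss.backward()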

BinaryBiTemperedLogisticLoss

Bases: Module

Modification of BiTemperedLogisticLoss for binary classification case.

Parameters:

t1 (float, required): Temperature 1 (< 1.0 for boundedness).
t2 (float, required): Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
label_smooth (float, default 0.0): Label smoothing parameter in [0, 1).
ignore_index (Optional[int], default None): Index to ignore.
reduction (str, default 'mean'): Type of reduction.
Source code in pytorch_optimizer/loss/bi_tempered.py
class BinaryBiTemperedLogisticLoss(nn.Module):
    """Modification of BiTemperedLogisticLoss for binary classification case.

    :param t1: float. Temperature 1 (< 1.0 for boundedness).
    :param t2: float. Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    :param label_smooth: float. Label smoothing parameter between [0, 1).
    :param ignore_index: Optional[int]. Index to ignore.
    :param reduction: str. type of reduction.
    """

    def __init__(
        self,
        t1: float,
        t2: float,
        label_smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        reduction: str = 'mean',
    ):
        super().__init__()
        self.t1 = t1
        self.t2 = t2
        self.label_smooth = label_smooth
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        if predictions.size(1) != 1 or targets.size(1) != 1:
            raise ValueError('Channel dimension for predictions and targets must be equal to 1')

        loss = bi_tempered_logistic_loss(
            torch.cat((-predictions, predictions), dim=1).moveaxis(1, -1),
            torch.cat((1.0 - targets, targets), dim=1).moveaxis(1, -1),
            t1=self.t1,
            t2=self.t2,
            label_smooth=self.label_smooth,
            reduction='none',
        ).unsqueeze(dim=1)

        if self.ignore_index is not None:
            mask = targets.eq(self.ignore_index)
            loss = torch.masked_fill(loss, mask, value=0)

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
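
A minimal usage sketch for the binary case (shapes are illustrative; the channel dimension must be 1, as enforced in forward):

import torch
from pytorch_optimizer.loss.bi_tempered import BinaryBiTemperedLogisticLoss

criterion = BinaryBiTemperedLogisticLoss(t1=0.8, t2=1.2)
logits = torch.randn(4, 1, 16, 16, requires_grad=True)  # single-channel logits
masks = torch.randint(0, 2, (4, 1, 16, 16)).float()     # binary targets in {0, 1}
loss = criterion(logits, masks)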

BCELoss

Bases: Module

binary cross entropy with label smoothing + probability input.

Parameters:

label_smooth (float, default 0.0): Label smoothing factor.
eps (float, default 1e-06): Epsilon.
reduction (str, default 'mean'): Type of reduction.
Source code in pytorch_optimizer/loss/cross_entropy.py
class BCELoss(nn.Module):
    r"""binary cross entropy with label smoothing + probability input.

    :param label_smooth: float. Smoothness constant for dice coefficient (a).
    :param eps: float. epsilon.
    :param reduction: str. type of reduction.
    """

    def __init__(self, label_smooth: float = 0.0, eps: float = 1e-6, reduction: str = 'mean'):
        super().__init__()
        self.label_smooth = label_smooth
        self.eps = eps
        self.reduction = reduction

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.training and self.label_smooth > 0.0:
            y_true = (1.0 - self.label_smooth) * y_true + self.label_smooth / y_pred.size(-1)
        y_pred = torch.clamp(y_pred, self.eps, 1.0 - self.eps)
        return binary_cross_entropy(y_pred, y_true, reduction=self.reduction)
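
A minimal usage sketch (note the input is probabilities, not logits; shapes are illustrative):

import torch
from pytorch_optimizer.loss.cross_entropy import BCELoss

criterion = BCELoss(label_smooth=0.1)
probs = torch.rand(8, 1, requires_grad=True)   # probabilities in [0, 1)
targets = torch.randint(0, 2, (8, 1)).float()
loss = criterion(probs, targets)

As the forward pass shows, label smoothing is applied only while the module is in training mode.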

SoftF1Loss

Bases: Module

Soft-F1 loss.

Parameters:

beta (float, default 1.0): Beta of the F-beta score.
eps (float, default 1e-06): Epsilon.
Source code in pytorch_optimizer/loss/f1.py
class SoftF1Loss(nn.Module):
    r"""Soft-F1 loss.

    :param beta: float. f-beta.
    :param eps: float. epsilon.
    """

    def __init__(self, beta: float = 1.0, eps: float = 1e-6):
        super().__init__()
        self.beta = beta
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        tp = (y_true * y_pred).sum().float()
        fn = ((1 - y_true) * y_pred).sum().float()
        fp = (y_true * (1 - y_pred)).sum().float()

        p = tp / (tp + fp + self.eps)
        r = tp / (tp + fn + self.eps)

        f1 = (1 + self.beta ** 2) * (p * r) / ((self.beta ** 2) * p + r + self.eps)  # fmt: skip
        f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)

        return 1.0 - f1.mean()
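
A minimal usage sketch (probability inputs with binary targets; shapes are illustrative):

import torch
from pytorch_optimizer.loss.f1 import SoftF1Loss

criterion = SoftF1Loss(beta=1.0)
probs = torch.rand(16, requires_grad=True)    # predicted probabilities
targets = torch.randint(0, 2, (16,)).float()  # binary ground truth
loss = criterion(probs, targets)              # 1 - soft F-beta score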

FocalLoss

Bases: Module

Focal loss function w/ logit input.

Parameters:

alpha (float, default 1.0): Alpha (weighting factor).
gamma (float, default 2.0): Gamma (focusing parameter).
Source code in pytorch_optimizer/loss/focal.py
class FocalLoss(nn.Module):
    r"""Focal loss function w/ logit input.

    :param alpha: float. alpha.
    :param gamma: float. gamma.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        bce_loss = binary_cross_entropy_with_logits(y_pred, y_true, reduction='none')
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        return focal_loss.mean()
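
A minimal usage sketch (logit inputs with binary targets; shapes are illustrative):

import torch
from pytorch_optimizer.loss.focal import FocalLoss

criterion = FocalLoss(alpha=1.0, gamma=2.0)
logits = torch.randn(8, 1, requires_grad=True)
targets = torch.randint(0, 2, (8, 1)).float()
loss = criterion(logits, targets)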

FocalCosineLoss

Bases: Module

Focal Cosine loss function w/ logit input.

Parameters:

alpha (float, default 1.0): Alpha (weighting factor for the focal term).
gamma (float, default 2.0): Gamma (focusing parameter).
focal_weight (float, default 0.1): Weight of the focal loss term.
reduction (str, default 'mean'): Type of reduction.
Source code in pytorch_optimizer/loss/focal.py
class FocalCosineLoss(nn.Module):
    r"""Focal Cosine loss function w/ logit input.

    :param alpha: float. alpha.
    :param gamma: float. gamma.
    :param focal_weight: float. weight of focal loss.
    :param reduction: str. type of reduction.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, focal_weight: float = 0.1, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.focal_weight = focal_weight
        self.reduction = reduction

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        cosine_loss = cosine_embedding_loss(
            y_pred,
            one_hot(y_true, num_classes=y_pred.size(-1)),
            torch.tensor([1], device=y_true.device),
            reduction=self.reduction,
        )

        ce_loss = cross_entropy(normalize(y_pred), y_true, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (self.alpha * (1 - pt) ** self.gamma * ce_loss).mean()

        return cosine_loss + self.focal_weight * focal_loss
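
A minimal usage sketch (logit inputs with integer class targets; shapes are illustrative):

import torch
from pytorch_optimizer.loss.focal import FocalCosineLoss

criterion = FocalCosineLoss(alpha=1.0, gamma=2.0, focal_weight=0.1)
logits = torch.randn(8, 10, requires_grad=True)  # (batch, num_classes)
targets = torch.randint(0, 10, (8,))             # class indices
loss = criterion(logits, targets)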

BCEFocalLoss

Bases: Module

BCEFocal loss function w/ probability input.

Parameters:

alpha (float, default 0.25): Alpha (weighting factor).
gamma (float, default 2.0): Gamma (focusing parameter).
label_smooth (float, default 0.0): Label smoothing factor.
eps (float, default 1e-06): Epsilon.
reduction (str, default 'mean'): Type of reduction ('mean' or 'sum').
Source code in pytorch_optimizer/loss/focal.py
class BCEFocalLoss(nn.Module):
    r"""BCEFocal loss function w/ probability input.

    :param alpha: float. alpha.
    :param gamma: float. gamma.
    :param label_smooth: float. Smoothness constant for dice coefficient (a).
    :param eps: float. epsilon.
    :param reduction: str. type of reduction.
    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        label_smooth: float = 0.0,
        eps: float = 1e-6,
        reduction: str = 'mean',
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

        self.bce = BCELoss(label_smooth=label_smooth, eps=eps, reduction='none')

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        bce_loss = self.bce(y_pred, y_true)
        focal_loss = (
            y_true * self.alpha * (1.0 - y_pred) ** self.gamma * bce_loss
            + (1.0 - y_true) ** self.gamma * bce_loss
        )  # fmt: skip

        return focal_loss.mean() if self.reduction == 'mean' else focal_loss.sum()
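
A minimal usage sketch (probability inputs with binary targets; shapes are illustrative):

import torch
from pytorch_optimizer.loss.focal import BCEFocalLoss

criterion = BCEFocalLoss(alpha=0.25, gamma=2.0)
probs = torch.rand(8, 1, requires_grad=True)   # probabilities in [0, 1)
targets = torch.randint(0, 2, (8, 1)).float()
loss = criterion(probs, targets)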

FocalTverskyLoss

Bases: Module

Focal Tversky Loss w/ logits input.

Parameters:

alpha (float, default 0.5): Alpha (weight of false positives).
beta (float, default 0.5): Beta (weight of false negatives).
gamma (float, default 1.0): Gamma (exponent applied to the Tversky loss).
smooth (float, default 1e-06): Smooth factor.
Source code in pytorch_optimizer/loss/focal.py
class FocalTverskyLoss(nn.Module):
    r"""Focal Tversky Loss w/ logits input.

    :param alpha: float. alpha.
    :param beta: float. beta.
    :param gamma: float. gamma.
    :param smooth: float. smooth factor.
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, smooth: float = 1e-6):
        super().__init__()
        self.gamma = gamma

        self.tversky = TverskyLoss(alpha, beta, smooth)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        return self.tversky(y_pred, y_true) ** self.gamma
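
A minimal usage sketch (logit inputs for binary segmentation; shapes and hyper-parameter values are illustrative):

import torch
from pytorch_optimizer.loss.focal import FocalTverskyLoss

criterion = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=0.75)
logits = torch.randn(4, 1, 16, 16, requires_grad=True)
masks = torch.randint(0, 2, (4, 1, 16, 16)).float()
loss = criterion(logits, masks)  # tversky_loss ** gamma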

soft_jaccard_score(output, target, label_smooth=0.0, eps=1e-06, dims=None)

Get soft jaccard score.

Parameters:

output (Tensor, required): Predicted segments.
target (Tensor, required): Ground truth segments.
label_smooth (float, default 0.0): Label smoothing factor.
eps (float, default 1e-06): Epsilon.
dims (Optional[Tuple[int, ...]], default None): Target dimensions to reduce.
Source code in pytorch_optimizer/loss/jaccard.py
def soft_jaccard_score(
    output: torch.Tensor,
    target: torch.Tensor,
    label_smooth: float = 0.0,
    eps: float = 1e-6,
    dims: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    r"""Get soft jaccard score.

    :param output: torch.Tensor. predicted segments.
    :param target: torch.Tensor. ground truth segments.
    :param label_smooth: float. label smoothing factor.
    :param eps: float. epsilon.
    :param dims: Optional[Tuple[int, ...]]. target dimensions to reduce.
    """
    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)

    return (intersection + label_smooth) / (cardinality - intersection + label_smooth).clamp_min(eps)
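
A minimal usage sketch (probability inputs; the dims choice below reduces over batch and pixel dimensions to give one score per class, and the shapes are illustrative):

import torch
from pytorch_optimizer.loss.jaccard import soft_jaccard_score

preds = torch.rand(4, 3, 256)                       # (batch, classes, pixels) probabilities
targets = torch.randint(0, 2, (4, 3, 256)).float()  # multi-hot ground truth
scores = soft_jaccard_score(preds, targets, dims=(0, 2))  # per-class soft IoU, shape (3,)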

JaccardLoss

Bases: _Loss

Jaccard loss for image segmentation task. It supports binary, multiclass and multilabel cases.

Reference : https://github.com/BloodAxe/pytorch-toolbelt

Parameters:

mode (CLASS_MODE, required): Loss mode: 'binary', 'multiclass', or 'multilabel'.
classes (Optional[List[int]], default None): List of classes that contribute to the loss computation. By default, all channels are included.
log_loss (bool, default False): If True, the loss is computed as -log(jaccard); otherwise as 1 - jaccard.
from_logits (bool, default True): If True, assumes the input is raw logits.
label_smooth (float, default 0.0): Smoothness constant added to the intersection and union.
eps (float, default 1e-06): Epsilon.
Source code in pytorch_optimizer/loss/jaccard.py
class JaccardLoss(_Loss):
    r"""Jaccard loss for image segmentation task. It supports binary, multiclass and multilabel cases.

    Reference : https://github.com/BloodAxe/pytorch-toolbelt

    :param mode: CLASS_MODE. loss mode 'binary', 'multiclass', or 'multilabel.
    :param classes: Optional[List[int]]. List of classes that contribute in loss computation. By default,
        all channels are included.
    :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
    :param from_logits: bool. If True, assumes input is raw logits.
    :param label_smooth: float. Smoothness constant for dice coefficient (a).
    :param eps: float. epsilon.
    """

    def __init__(
        self,
        mode: CLASS_MODE,
        classes: List[int] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        label_smooth: float = 0.0,
        eps: float = 1e-6,
    ):
        super().__init__()

        if classes is not None:
            if mode == 'binary':
                raise ValueError('[-] Masking classes is not supported with mode=binary')

            classes = torch.LongTensor(classes)

        self.mode = mode
        self.classes = classes
        self.log_loss = log_loss
        self.from_logits = from_logits
        self.label_smooth = label_smooth
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            y_pred = y_pred.log_softmax(dim=1).exp() if self.mode == 'multiclass' else logsigmoid(y_pred).exp()

        bs: int = y_true.size(0)
        num_classes: int = y_pred.size(1)

        dims: Tuple[int, ...] = (0, 2)

        if self.mode == 'binary':
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

        if self.mode == 'multiclass':
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            y_true = one_hot(y_true, num_classes)
            y_true = y_true.permute(0, 2, 1)

        if self.mode == 'multilabel':
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

        scores = soft_jaccard_score(
            y_pred, y_true.type(y_pred.dtype), label_smooth=self.label_smooth, eps=self.eps, dims=dims
        )

        loss = -torch.log(scores.clamp_min(self.eps)) if self.log_loss else 1.0 - scores

        # IoU loss is defined for non-empty classes
        # So we zero contribution of channel that does not have true pixels
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss

        mask = y_true.sum(dims) > 0
        loss *= mask.float()

        if self.classes is not None:
            loss = loss[self.classes]

        return loss.mean()
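
A minimal usage sketch for the multiclass mode (logit inputs with class-index masks; shapes are illustrative):

import torch
from pytorch_optimizer.loss.jaccard import JaccardLoss

criterion = JaccardLoss(mode='multiclass')
logits = torch.randn(4, 5, 16, 16, requires_grad=True)  # (batch, classes, H, W)
masks = torch.randint(0, 5, (4, 16, 16))                # class index per pixel
loss = criterion(logits, masks)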

LDAMLoss

Bases: Module

LDAM Loss.

Parameters:

num_class_list (List[int], required): Number of samples per class.
max_m (float, default 0.5): Max margin (the C term in the paper).
weight (Optional[Tensor], default None): Class weight.
s (float, default 30.0): Scaling factor applied to the logits.
Source code in pytorch_optimizer/loss/ldam.py
class LDAMLoss(nn.Module):
    r"""LDAM Loss.

    :param num_class_list: List[int]. list of number of class.
    :param max_m: float. max margin (`C` term in the paper).
    :param weight: Optional[torch.Tensor]. class weight.
    :param s: float. scaler.
    """

    def __init__(
        self, num_class_list: List[int], max_m: float = 0.5, weight: Optional[torch.Tensor] = None, s: float = 30.0
    ):
        super().__init__()

        cls_num_list = np.asarray(num_class_list)
        m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
        m_list *= max_m / np.max(m_list)

        self.m_list = torch.FloatTensor(m_list).unsqueeze(0)
        self.weight = weight
        self.s = s

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        index = torch.zeros_like(y_pred, dtype=torch.bool)
        index.scatter_(1, y_true.view(-1, 1), 1)

        batch_m = torch.matmul(self.m_list.to(index.device), index.float().transpose(0, 1))
        batch_m = batch_m.view((-1, 1))
        x_m = y_pred - batch_m

        output = torch.where(index, x_m, y_pred)
        return cross_entropy(self.s * output, y_true, weight=self.weight)
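
A minimal usage sketch (the per-class sample counts and shapes below are illustrative):

import torch
from pytorch_optimizer.loss.ldam import LDAMLoss

criterion = LDAMLoss(num_class_list=[500, 100, 10])  # samples per class (imbalanced)
logits = torch.randn(8, 3, requires_grad=True)       # (batch, num_classes)
targets = torch.randint(0, 3, (8,))
loss = criterion(logits, targets)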

LovaszHingeLoss

Bases: Module

Binary Lovasz hinge loss.

Parameters:

per_image (bool, default True): Compute the loss per image instead of per batch.
Source code in pytorch_optimizer/loss/lovasz.py
class LovaszHingeLoss(nn.Module):
    r"""Binary Lovasz hinge loss.

    :param per_image: bool. compute the loss per image instead of per batch.
    """

    def __init__(self, per_image: bool = True):
        super().__init__()
        self.per_image = per_image

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        if not self.per_image:
            return lovasz_hinge_flat(y_pred, y_true)
        return sum(lovasz_hinge_flat(y_p, y_t) for y_p, y_t in zip(y_pred, y_true)) / y_pred.size()[0]
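
A minimal usage sketch (assuming each image's logits and labels are already flattened to one dimension per sample; shapes are illustrative):

import torch
from pytorch_optimizer.loss.lovasz import LovaszHingeLoss

criterion = LovaszHingeLoss(per_image=True)
logits = torch.randn(4, 256, requires_grad=True)  # (batch, pixels) logits
labels = torch.randint(0, 2, (4, 256)).float()    # binary ground truth
loss = criterion(logits, labels)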

TverskyLoss

Bases: Module

Tversky Loss w/ logits input.

Parameters:

alpha (float, default 0.5): Alpha (weight of false positives).
beta (float, default 0.5): Beta (weight of false negatives).
smooth (float, default 1e-06): Smooth factor.
Source code in pytorch_optimizer/loss/tversky.py
class TverskyLoss(nn.Module):
    r"""Tversky Loss w/ logits input.

    :param alpha: float. alpha.
    :param beta: float. beta.
    :param smooth: float. smooth factor.
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1e-6):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        y_pred = torch.sigmoid(y_pred)

        y_pred = y_pred.view(-1)
        y_true = y_true.view(-1)

        tp = (y_pred * y_true).sum()
        fp = ((1.0 - y_true) * y_pred).sum()
        fn = (y_true * (1.0 - y_pred)).sum()

        loss = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)

        return 1.0 - loss
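
A minimal usage sketch (logit inputs for binary segmentation; alpha and beta weight false positives and false negatives, and the values and shapes are illustrative):

import torch
from pytorch_optimizer.loss.tversky import TverskyLoss

criterion = TverskyLoss(alpha=0.7, beta=0.3)
logits = torch.randn(4, 1, 16, 16, requires_grad=True)
masks = torch.randint(0, 2, (4, 1, 16, 16)).float()
loss = criterion(logits, masks)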