AsymmetricLoss
- class torch_ecg.models.loss.AsymmetricLoss(gamma_neg: numbers.Real = 4, gamma_pos: numbers.Real = 1, prob_margin: float = 0.05, disable_torch_grad_focal_loss: bool = False, reduction: str = 'mean', implementation: str = 'alibaba-miil')
Bases: torch.nn.modules.module.Module
Asymmetric loss class.
The asymmetric loss was proposed in [1], with the official implementation in [2]. It is defined as
\[ASL = \begin{cases} L_+ := (1-p)^{\gamma_+} \log(p) \\ L_- := (p_m)^{\gamma_-} \log(1-p_m) \end{cases}\]
where \(p\) is the predicted probability of a label and \(p_m = \max(p-m, 0)\) is the shifted probability with probability margin \(m\). The loss on one label of one sample is
\[L = -yL_+ - (1-y)L_-\]
- Parameters
gamma_neg (numbers.Real, default 4) – Exponent \(\gamma_-\) of the multiplier in the negative loss \(L_-\).
gamma_pos (numbers.Real, default 1) – Exponent \(\gamma_+\) of the multiplier in the positive loss \(L_+\).
prob_margin (float, default 0.05) – The probability margin \(m\) used to compute the shifted probability \(p_m\) of negative labels.
disable_torch_grad_focal_loss (bool, default False) – If True, disable gradient computation while computing the asymmetric focal terms of the loss.
reduction ({"none", "mean", "sum"}, default "mean") – Specifies the reduction to apply to the output.
implementation ({"alibaba-miil", "deep-psp"}, default "alibaba-miil") – Whether to use the implementation by Alibaba-MIIL or the one by DeepPSP; case-insensitive.
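The formula above can be checked against a small pure-Python sketch. It is not the torch_ecg implementation: `asl_term` and `asl` are hypothetical names, the inputs are assumed to be probabilities \(p\) in [0, 1] (whether `forward` expects logits or probabilities is not stated on this page), the `eps` clamp inside the logarithms is an added assumption for numerical safety, and "mean" is assumed to average over all batch_size × n_classes entries.

```python
import math
from typing import List, Sequence, Union


def asl_term(p: float, y: int, gamma_pos: float = 1.0, gamma_neg: float = 4.0,
             margin: float = 0.05, eps: float = 1e-8) -> float:
    """Asymmetric loss for one label of one sample: L = -y*L_+ - (1-y)*L_-."""
    # L_+ = (1 - p)^gamma_+ * log(p); eps (an assumption) keeps log() finite
    l_pos = (1.0 - p) ** gamma_pos * math.log(max(p, eps))
    # Shifted probability p_m = max(p - m, 0): very easy negatives get zero loss
    p_m = max(p - margin, 0.0)
    # L_- = p_m^gamma_- * log(1 - p_m)
    l_neg = p_m ** gamma_neg * math.log(max(1.0 - p_m, eps))
    return -y * l_pos - (1 - y) * l_neg


def asl(probs: Sequence[Sequence[float]], targets: Sequence[Sequence[int]],
        reduction: str = "mean", **kwargs) -> Union[float, List[List[float]]]:
    """Apply asl_term elementwise to (batch_size, n_classes) nested lists."""
    losses = [[asl_term(p, y, **kwargs) for p, y in zip(row_p, row_y)]
              for row_p, row_y in zip(probs, targets)]
    if reduction == "none":
        return losses
    total = sum(sum(row) for row in losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        # Assumption: average over all batch_size * n_classes entries
        return total / sum(len(row) for row in losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


probs = [[0.9, 0.03, 0.5]]   # predicted probabilities, shape (1, 3)
targets = [[1, 0, 0]]        # binary multi-label targets, same shape
print(asl(probs, targets, reduction="mean"))
```

Setting gamma_pos = gamma_neg = 0 and prob_margin = 0 makes both multipliers equal to 1 and \(p_m = p\), so the sketch reduces to ordinary binary cross-entropy, which is a convenient sanity check.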
Note
Since AsymmetricLoss aims at emphasizing the contribution of positive samples, gamma_neg is usually greater than gamma_pos, so that easy negatives are down-weighted more strongly than positives.
TODO
Evaluate the setting where gamma_neg and gamma_pos are tensors of shape (1, n_classes), in which case each class would have its own ratio of positive to negative weighting.
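To illustrate the Note, the sketch below compares the multipliers applied to the log terms, \((1-p)^{\gamma_+}\) for a positive label and \(p_m^{\gamma_-}\) for a negative one, using the default gamma_pos = 1, gamma_neg = 4 and prob_margin = 0.05. `focusing_weight` is a hypothetical helper written for this illustration, not part of torch_ecg.

```python
def focusing_weight(p: float, y: int, gamma_pos: float = 1.0,
                    gamma_neg: float = 4.0, margin: float = 0.05) -> float:
    """Multiplier on the log term of the asymmetric loss for one label."""
    if y == 1:
        # Positive label: (1 - p)^gamma_+
        return (1.0 - p) ** gamma_pos
    # Negative label: p_m^gamma_-, with shifted probability p_m = max(p - m, 0)
    p_m = max(p - margin, 0.0)
    return p_m ** gamma_neg


# A positive predicted at p = 0.7 and a negative predicted at p = 0.3 are
# equally far from their targets, yet the negative is down-weighted far more.
w_pos = focusing_weight(0.7, 1)
w_neg = focusing_weight(0.3, 0)
print(f"positive weight: {w_pos:.4f}, negative weight: {w_neg:.4f}")
```

Here the positive keeps a multiplier of about 0.30 while the negative's drops to about 0.0039, and any negative with \(p \le m\) contributes nothing at all.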
References
- [1] Ridnik, Tal, et al. “Asymmetric Loss for Multi-Label Classification.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
- [2]
- forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
input (torch.Tensor) – The predicted value tensor, of shape (batch_size, n_classes).
target (torch.Tensor) – The target tensor, of shape (batch_size, n_classes).
- Returns
The asymmetric loss.
- Return type