AsymmetricLoss

class torch_ecg.models.loss.AsymmetricLoss(gamma_neg: numbers.Real = 4, gamma_pos: numbers.Real = 1, prob_margin: float = 0.05, disable_torch_grad_focal_loss: bool = False, reduction: str = 'mean', implementation: str = 'alibaba-miil')

Bases: torch.nn.modules.module.Module

Asymmetric loss class.

The asymmetric loss was proposed in [1], with an official implementation in [2]. It is defined as

\[ASL = \begin{cases} L_+ := (1-p)^{\gamma_+} \log(p) \\ L_- := (p_m)^{\gamma_-} \log(1-p_m) \end{cases}\]

where \(p_m = \max(p-m, 0)\) is the shifted probability, with probability margin \(m\). The loss on one label of one sample is

\[L = -yL_+ - (1-y)L_-\]
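The formulas above can be sketched for a single label in plain Python. This is only an illustration of the definition, not the library's implementation: the function name `asl_single_label` and the `eps` guard against `log(0)` are assumptions introduced here, and `p` is assumed to be a probability in [0, 1].

```python
import math


def asl_single_label(p, y, gamma_pos=1.0, gamma_neg=4.0, margin=0.05, eps=1e-8):
    """Per-label asymmetric loss for a predicted probability p and a binary target y.

    Hypothetical helper that follows the formulas in the text term by term.
    """
    # Shifted probability p_m = max(p - m, 0); only the negative term uses it.
    p_m = max(p - margin, 0.0)
    # L_+ = (1 - p)^gamma_+ * log(p); eps (an assumption) avoids log(0).
    l_pos = (1.0 - p) ** gamma_pos * math.log(max(p, eps))
    # L_- = p_m^gamma_- * log(1 - p_m)
    l_neg = p_m ** gamma_neg * math.log(max(1.0 - p_m, eps))
    # L = -y * L_+ - (1 - y) * L_-
    return -y * l_pos - (1 - y) * l_neg


# A negative label whose probability is below the margin contributes nothing.
print(asl_single_label(0.03, 0))
# A confidently wrong negative is penalized far more than an easy one.
print(asl_single_label(0.9, 0), asl_single_label(0.1, 0))
```

Because \(p_m\) is clipped at zero, every negative label with \(p \le m\) is dropped from the loss entirely, which is the effect of the probability margin.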
Parameters
  • gamma_neg (numbers.Real, default 4) – Exponent of the multiplier to the negative loss.

  • gamma_pos (numbers.Real, default 1) – Exponent of the multiplier to the positive loss.

  • prob_margin (float, default 0.05) – The probability margin \(m\) by which the predicted probability is shifted in the negative loss term.

  • disable_torch_grad_focal_loss (bool, default False) – If True, disable PyTorch gradient tracking while the asymmetric focusing weights are computed.

  • reduction ({"none", "mean", "sum"}, optional) – Specifies the reduction to apply to the output, by default “mean”.

  • implementation ({"alibaba-miil", "deep-psp"}, optional) – Which implementation to use, the one by Alibaba-MIIL or the one by DeepPSP (case insensitive), by default “alibaba-miil”.

Note

Since AsymmetricLoss aims to emphasize the contribution of positive samples by down-weighting easy negatives, gamma_neg is usually set greater than gamma_pos.
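To make the role of the two exponents concrete, the short sketch below prints the multipliers \((1-p)^{\gamma_+}\) and \((p_m)^{\gamma_-}\) from the definition for a few probabilities, using the default parameter values. The helper name `focusing_weights` is hypothetical and exists only for this illustration.

```python
def focusing_weights(p, gamma_pos=1.0, gamma_neg=4.0, margin=0.05):
    """Return the multipliers applied to the positive and negative log terms.

    Hypothetical helper; p is assumed to be a probability in [0, 1].
    """
    # Multiplier of log(p) in the positive term: (1 - p)^gamma_+
    pos_weight = (1.0 - p) ** gamma_pos
    # Multiplier of log(1 - p_m) in the negative term: p_m^gamma_-,
    # where p_m = max(p - margin, 0) is the shifted probability.
    neg_weight = max(p - margin, 0.0) ** gamma_neg
    return pos_weight, neg_weight


for p in (0.1, 0.5, 0.9):
    pos_w, neg_w = focusing_weights(p)
    print(f"p={p:.1f}  positive weight={pos_w:.4f}  negative weight={neg_w:.6f}")
```

With a larger gamma_neg, the negative multiplier shrinks much faster as p approaches zero, so easy negatives contribute little while positives keep a comparatively large weight.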

TODO

  1. Evaluate the setting in which gamma_neg and gamma_pos are tensors of shape (1, n_classes), in which case there would be one ratio of positive to negative for each class.

References

[1] Ridnik, Tal, et al. “Asymmetric Loss for Multi-Label Classification.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.

[2] https://github.com/Alibaba-MIIL/ASL/

forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Forward pass.

Parameters
  • input (torch.Tensor) – The predicted value tensor, of shape (batch_size, n_classes).

  • target (torch.Tensor) – The target tensor, of shape (batch_size, n_classes).

Returns

The asymmetric loss.

Return type

torch.Tensor