FocalLoss
- class torch_ecg.models.loss.FocalLoss(gamma: float = 2.0, weight: Optional[torch.Tensor] = None, class_weight: Optional[torch.Tensor] = None, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = 'mean', multi_label: bool = True, **kwargs: Any)
Bases: torch.nn.modules.loss._WeightedLoss
Focal loss class.
The focal loss was proposed in [1]; this implementation is based on [2], [3], and [4]. It is computed as follows:
\[\operatorname{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \log(p_t)\]
where:
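\(p_t\) is the model's estimated probability of the ground-truth label (\(p\) if the target is 1, \(1 - p\) otherwise), \(\alpha_t\) is a class-dependent weighting factor, and \(\gamma \ge 0\) is the focusing parameter.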
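To make the formula concrete, the following minimal pure-Python sketch evaluates FL(p_t) for a single binary prediction. It assumes a sigmoid output (as in the multi-label case) and α_t = 1 by default; the helper names `sigmoid` and `focal_term` are hypothetical and not part of torch_ecg.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def focal_term(logit: float, target: int, gamma: float = 2.0, alpha_t: float = 1.0) -> float:
    """Focal loss of one binary prediction: -alpha_t * (1 - p_t)**gamma * log(p_t).

    p_t is the estimated probability of the ground-truth label:
    p if target == 1, else 1 - p.
    """
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# An "easy" example (confidently correct) and a "hard" one (confidently wrong).
easy = focal_term(logit=3.0, target=1)
hard = focal_term(logit=-3.0, target=1)

# With gamma = 0 and alpha_t = 1, the focal loss reduces to binary cross-entropy.
ce_easy = focal_term(logit=3.0, target=1, gamma=0.0)
print(f"easy: FL={easy:.5f}, CE={ce_easy:.5f}")
print(f"hard: FL={hard:.5f}")
```

With the default γ = 2, the confidently correct example keeps only a small fraction of its cross-entropy, while the confidently wrong one keeps most of it; this is how the (1 - p_t)^γ factor shifts training toward hard examples.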
- Parameters
gamma (float, default 2.0) – The focusing parameter \(\gamma\) of the focal loss; larger values down-weight well-classified examples more strongly.
weight (torch.Tensor, optional) – If multi_label is True, a manual rescaling weight applied to the loss of each batch element, of size batch_size; if multi_label is False, a weight for each class, of size n_classes.
class_weight (torch.Tensor, optional) – The class weight, of shape (1, n_classes).
size_average (bool, optional) – Not used; kept for consistency with PyTorch's native loss classes.
reduce (bool, optional) – Not used; kept for consistency with PyTorch's native loss classes.
reduction ({"none", "mean", "sum"}, optional) – The reduction to apply to the output. Default: "mean".
multi_label (bool, default True) – If True, the loss is computed for multi-label classification; otherwise, for single-label classification.
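The sketch below illustrates how a per-class weight of shape (1, n_classes) can be broadcast over a (batch_size, n_classes) matrix of per-element losses before a PyTorch-style reduction is applied. It is written in plain Python for clarity and rests on two assumptions not confirmed by the parameter descriptions: that class_weight scales each class column (playing the role of α_t), and that "mean" averages over all elements. The helpers `apply_class_weight` and `reduce_loss` are hypothetical, and the per-batch-element `weight` argument is not modeled.

```python
from typing import List, Sequence, Union

Matrix = List[List[float]]


def apply_class_weight(loss: Matrix, class_weight: Sequence[float]) -> Matrix:
    """Scale column j of a (batch_size, n_classes) loss matrix by class_weight[j].

    A class_weight of shape (1, n_classes) is represented as a flat sequence
    of length n_classes and broadcast over the batch dimension.
    """
    return [[w * v for w, v in zip(class_weight, row)] for row in loss]


def reduce_loss(loss: Matrix, reduction: str = "mean") -> Union[float, Matrix]:
    """Apply a PyTorch-style reduction: "none", "mean" or "sum"."""
    if reduction == "none":
        return loss
    total = sum(sum(row) for row in loss)
    if reduction == "sum":
        return total
    if reduction == "mean":
        n_elements = sum(len(row) for row in loss)
        return total / n_elements
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Hypothetical per-element focal loss values for batch_size=2, n_classes=3.
per_element = [[0.1, 0.4, 0.0], [0.2, 0.0, 0.3]]
weighted = apply_class_weight(per_element, class_weight=[1.0, 2.0, 0.5])
print(reduce_loss(weighted, "sum"))
print(reduce_loss(weighted, "mean"))
```

Setting a class's weight above 1 increases its share of the reduced loss, which is one common way to counter class imbalance in multi-label ECG classification.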
References
- [1] Lin, Tsung-Yi, et al. “Focal loss for dense object detection.” Proceedings of the IEEE International Conference on Computer Vision, 2017.
- [2] https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
- [3] https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
- [4] https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327
- forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
input (torch.Tensor) – The predicted value tensor (before sigmoid), of shape (batch_size, n_classes).
target (torch.Tensor) – Multi-label binarized vector of shape (batch_size, n_classes), or single-label binarized vector of shape (batch_size,).
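For the multi-label case, a numerically stable way to evaluate the focal loss from pre-sigmoid inputs is to compute binary cross-entropy on the logits and recover p_t as exp(-BCE), since BCE = -log(p_t). The following plain-Python sketch illustrates that approach under stated assumptions (α_t = 1, "mean" reduction, no weights); it is not a reproduction of torch_ecg's `forward`, and the helper names `bce_with_logits` and `focal_loss_multilabel` are hypothetical.

```python
import math
from typing import List, Sequence


def bce_with_logits(x: float, y: float) -> float:
    """Numerically stable binary cross-entropy on a logit x with target y in {0, 1}.

    Uses the standard stable form max(x, 0) - x * y + log(1 + exp(-|x|)).
    """
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def focal_loss_multilabel(
    logits: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    gamma: float = 2.0,
) -> float:
    """Mean focal loss over a (batch_size, n_classes) batch of logits.

    Because BCE = -log(p_t), p_t is recovered as exp(-BCE), which avoids
    taking the log of a sigmoid output that may round to 0 or 1.
    """
    losses: List[float] = []
    for row_x, row_y in zip(logits, targets):
        for x, y in zip(row_x, row_y):
            ce = bce_with_logits(x, y)  # equals -log(p_t)
            p_t = math.exp(-ce)
            losses.append((1.0 - p_t) ** gamma * ce)  # (1 - p_t)^gamma * (-log p_t)
    return sum(losses) / len(losses)


# Hypothetical batch: batch_size=2, n_classes=3, with multi-label binarized targets.
logits = [[2.0, -1.0, 0.5], [-3.0, 4.0, 0.0]]
targets = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(focal_loss_multilabel(logits, targets))
```

In PyTorch the same idea is usually expressed with `torch.nn.functional.binary_cross_entropy_with_logits(..., reduction="none")` followed by `torch.exp(-bce)`, which keeps large-magnitude logits finite.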
- Returns
The focal loss.
- Return type