FocalLoss

class torch_ecg.models.loss.FocalLoss(gamma: float = 2.0, weight: Optional[torch.Tensor] = None, class_weight: Optional[torch.Tensor] = None, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = 'mean', multi_label: bool = True, **kwargs: Any)

Bases: torch.nn.modules.loss._WeightedLoss

Focal loss class.

The focal loss was proposed in [1]; this implementation is based on [2], [3], and [4]. The focal loss is computed as follows:

\[\operatorname{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \log(p_t)\]

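Where:

  • \(p_t\) is the model’s estimated probability of the ground-truth class. In the binary case, \(p_t = p\) if the target is 1 and \(p_t = 1 - p\) otherwise, where \(p\) is the predicted probability of the positive class.

  • \(\alpha_t\) is a weighting factor, defined analogously to \(p_t\) in [1]: \(\alpha_t = \alpha\) for the positive class and \(\alpha_t = 1 - \alpha\) otherwise.

  • \(\gamma \geq 0\) is the focusing parameter (gamma below). With \(\gamma = 0\), the formula reduces to \(\alpha\)-weighted cross-entropy.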
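As a worked illustration of the formula, the following plain-Python sketch computes the focal loss of a single binary prediction from a logit. It is not the torch_ecg implementation: the function name binary_focal_loss and its alpha argument are hypothetical (the class above has no alpha parameter). It uses a common implementation trick: compute the cross-entropy \(\mathrm{CE} = -\log(p_t)\) in a numerically stable way, then recover \(p_t = \exp(-\mathrm{CE})\).

```python
import math
from typing import Optional


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def binary_focal_loss(
    logit: float, target: int, gamma: float = 2.0, alpha: Optional[float] = None
) -> float:
    """Hypothetical sketch: FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t) for one prediction."""
    # -log(p_t) for p = sigmoid(logit): softplus(-logit) if target is 1, else softplus(logit)
    ce = softplus(-logit) if target == 1 else softplus(logit)
    p_t = math.exp(-ce)  # probability assigned to the ground-truth class
    if alpha is None:
        alpha_t = 1.0  # assumption: no class weighting when alpha is not given
    else:
        alpha_t = alpha if target == 1 else 1.0 - alpha
    return alpha_t * (1.0 - p_t) ** gamma * ce


# A confidently correct ("easy") prediction vs. a confidently wrong ("hard") one.
easy_ce = binary_focal_loss(4.0, 1, gamma=0.0)  # gamma = 0: plain cross-entropy
easy_fl = binary_focal_loss(4.0, 1, gamma=2.0)
hard_ce = binary_focal_loss(-4.0, 1, gamma=0.0)
hard_fl = binary_focal_loss(-4.0, 1, gamma=2.0)
print(f"easy: CE={easy_ce:.6f}  FL={easy_fl:.8f}")
print(f"hard: CE={hard_ce:.6f}  FL={hard_fl:.6f}")
```

With \(\gamma = 2\), the easy example's loss is scaled by \((1 - p_t)^2 \approx 3 \times 10^{-4}\), while the hard example keeps about 96% of its cross-entropy; this down-weighting of well-classified examples is what the focusing parameter controls.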

Parameters
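  • gamma (float, default 2.0) – The focusing parameter \(\gamma\) of focal loss. Larger values further down-weight the loss of well-classified examples; with gamma = 0, the formula reduces to weighted cross-entropy.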

  • weight (torch.Tensor, optional) – If multi_label is True, a manual rescaling weight applied to the loss of each batch element, of size batch_size; if multi_label is False, a weight for each class, of size n_classes.

  • class_weight (torch.Tensor, optional) – The class weight, of shape (1, n_classes).

  • size_average (bool, optional) – Not used; kept for consistency with the PyTorch native loss interface.

  • reduce (bool, optional) – Not used; kept for consistency with the PyTorch native loss interface.

  • reduction ({"none", "mean", "sum"}, optional) – Specifies the reduction to apply to the output, by default “mean”.

  • multi_label (bool, default True) – If True, the loss is computed for multi-label classification; otherwise, for single-label classification.
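To make the reduction and multi_label options concrete, here is a hypothetical plain-Python sketch of the multi-label case: an element-wise binary focal loss over a (batch_size, n_classes) grid of logits, followed by the chosen reduction. The helper name multilabel_focal_loss is illustrative, and two behaviours are assumptions rather than documented facts: "mean" averages over all batch_size × n_classes elements (as PyTorch's element-wise losses do), and class_weight multiplies each class column of the element-wise loss. The weight parameter is omitted.

```python
import math
from typing import List, Optional, Union


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def multilabel_focal_loss(
    logits: List[List[float]],
    targets: List[List[int]],
    gamma: float = 2.0,
    class_weight: Optional[List[float]] = None,
    reduction: str = "mean",
) -> Union[float, List[List[float]]]:
    """Hypothetical sketch: element-wise binary focal loss, then reduction."""
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"unknown reduction: {reduction!r}")
    losses = []
    for row_logits, row_targets in zip(logits, targets):
        row = []
        for j, (x, y) in enumerate(zip(row_logits, row_targets)):
            ce = softplus(-x) if y == 1 else softplus(x)  # -log(p_t) with p = sigmoid(x)
            p_t = math.exp(-ce)
            fl = (1.0 - p_t) ** gamma * ce
            if class_weight is not None:
                fl *= class_weight[j]  # assumption: one weight per class column
            row.append(fl)
        losses.append(row)
    if reduction == "none":
        return losses  # shape (batch_size, n_classes)
    flat = [v for row in losses for v in row]
    total = sum(flat)
    return total / len(flat) if reduction == "mean" else total


logits = [[2.0, -1.0, 0.5], [-3.0, 1.5, 0.0]]  # (batch_size=2, n_classes=3), before sigmoid
targets = [[1, 0, 1], [0, 1, 0]]  # multi-hot labels
per_element = multilabel_focal_loss(logits, targets, reduction="none")
print(multilabel_focal_loss(logits, targets, reduction="mean"))
print(multilabel_focal_loss(logits, targets, reduction="sum"))
```

Returning the unreduced grid for "none" lets callers apply their own masking or weighting before averaging.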

References

[1] Lin, Tsung-Yi, et al. “Focal loss for dense object detection.” Proceedings of the IEEE International Conference on Computer Vision. 2017.

[2] https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py

[3] https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py

[4] https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327

forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Forward pass.

Parameters
  • input (torch.Tensor) – The predicted logits (i.e. before applying sigmoid), of shape (batch_size, n_classes).

  • target (torch.Tensor) – For multi-label classification, a binarized (multi-hot) tensor of shape (batch_size, n_classes); for single-label classification, a vector of class labels of shape (batch_size,).
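For the single-label case, the following hypothetical plain-Python sketch assumes that the target holds integer class indices and that \(p_t\) is the softmax probability of the target class; it returns the unreduced per-sample losses. The function names and this interpretation are illustrative assumptions, not the library's code.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def single_label_focal_loss(
    logits: List[List[float]], target: List[int], gamma: float = 2.0
) -> List[float]:
    """Hypothetical sketch: per-sample focal loss with class-index targets, unreduced."""
    losses = []
    for row, k in zip(logits, target):
        ce = -log_softmax(row)[k]  # cross-entropy = -log(p_t), p_t = softmax(row)[k]
        p_t = math.exp(-ce)
        losses.append((1.0 - p_t) ** gamma * ce)
    return losses


logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 0.1]]  # (batch_size=2, n_classes=3)
target = [0, 2]  # class indices, shape (batch_size,)
print(single_label_focal_loss(logits, target))
```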

Returns

The focal loss: a scalar tensor if reduction is “mean” or “sum”, otherwise the unreduced element-wise loss.

Return type

torch.Tensor