MaskedBCEWithLogitsLoss
- class torch_ecg.models.loss.MaskedBCEWithLogitsLoss
Bases: torch.nn.modules.loss.BCEWithLogitsLoss
Masked Binary Cross Entropy Loss class.
This loss is used mainly for segmentation tasks in which some regions are much more important than others, for example the onsets and offsets of particular events (e.g. paroxysmal atrial fibrillation (AF) episodes).
This loss was proposed in [1], with reference to the loss function used in the U-Net paper [2].
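To make the idea of "regions of higher importance" concrete, the following is a minimal sketch of how a weight mask emphasizing event onsets and offsets might be built from a binary target. The helper `make_boundary_weight_mask` and its parameters `radius` and `boundary_weight` are hypothetical names introduced for illustration; they are not part of torch_ecg, and the weighting scheme used in [1] may differ.

```python
import torch
import torch.nn.functional as F


def make_boundary_weight_mask(
    target: torch.Tensor, radius: int = 1, boundary_weight: float = 5.0
) -> torch.Tensor:
    """Hypothetical helper (not a torch_ecg function).

    Returns a mask of the same shape as `target`, equal to
    `boundary_weight` within `radius` samples of a label change
    (an onset or offset) and 1.0 everywhere else.
    `target` is assumed binary, of shape (batch_size, sig_len, n_classes).
    """
    # A label change sits between sample t and sample t + 1.
    change = (target[:, 1:, :] != target[:, :-1, :]).float()
    # Mark both samples adjacent to each change.
    boundary = torch.zeros_like(target, dtype=torch.float)
    boundary[:, 1:, :] = torch.maximum(boundary[:, 1:, :], change)
    boundary[:, :-1, :] = torch.maximum(boundary[:, :-1, :], change)
    if radius > 0:
        # Widen the marked region along the time axis with max pooling.
        pooled = F.max_pool1d(
            boundary.permute(0, 2, 1),  # (batch_size, n_classes, sig_len)
            kernel_size=2 * radius + 1,
            stride=1,
            padding=radius,
        )
        boundary = pooled.permute(0, 2, 1)
    return 1.0 + (boundary_weight - 1.0) * boundary


# Example: one event spanning samples 4..6 in a 10-sample signal.
example_target = torch.zeros(1, 10, 1)
example_target[0, 4:7, 0] = 1.0
example_mask = make_boundary_weight_mask(example_target, radius=1)
```

The resulting mask has the same shape as the target, so it can be passed as the `weight_mask` argument of `forward`.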
References
- [1] Wen, Hao, and Jingsu Kang. “A comparative study on neural networks for paroxysmal atrial fibrillation events detection from electrocardiography.” Journal of Electrocardiology 75 (2022): 19-27.
- [2] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. “U-net: Convolutional networks for biomedical image segmentation.” International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2015.
- forward(input: torch.Tensor, target: torch.Tensor, weight_mask: torch.Tensor) -> torch.Tensor
Forward pass.
- Parameters
input (torch.Tensor) – The predicted value tensor (logits, before sigmoid), of shape (batch_size, sig_len, n_classes).
target (torch.Tensor) – The target tensor, of shape (batch_size, sig_len, n_classes).
weight_mask (torch.Tensor) – The weight mask tensor, of shape (batch_size, sig_len, n_classes).
- Returns
The masked binary cross entropy loss.
- Return type
torch.Tensor
Note
input, target, and weight_mask should be 3-D tensors of the same shape.
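As an illustration of the computation described above, here is a standalone sketch of a masked binary cross entropy on logits written with plain PyTorch. The function `masked_bce_with_logits` is a hypothetical stand-in, not the library's source, and the mean reduction over all elements is an assumption; the reduction used by `MaskedBCEWithLogitsLoss` may differ.

```python
import torch
import torch.nn.functional as F


def masked_bce_with_logits(
    input: torch.Tensor, target: torch.Tensor, weight_mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical stand-in for the documented forward pass.

    All three tensors are expected to have the same 3-D shape
    (batch_size, sig_len, n_classes).
    """
    if not (input.shape == target.shape == weight_mask.shape):
        raise ValueError("input, target, and weight_mask must share one shape")
    # Element-wise BCE computed directly on logits (no reduction yet).
    per_element = F.binary_cross_entropy_with_logits(
        input, target, reduction="none"
    )
    # Scale each element by its weight, then average (assumed reduction).
    return (per_element * weight_mask).mean()


torch.manual_seed(0)
batch_size, sig_len, n_classes = 2, 100, 1
logits = torch.randn(batch_size, sig_len, n_classes)

# One event per record, spanning samples 40..59.
target = torch.zeros(batch_size, sig_len, n_classes)
target[:, 40:60, :] = 1.0

# Weight 5 around the onset and offset, weight 1 elsewhere.
weight_mask = torch.ones_like(target)
weight_mask[:, 38:42, :] = 5.0
weight_mask[:, 58:62, :] = 5.0

loss = masked_bce_with_logits(logits, target, weight_mask)
```

Going by the `forward` signature documented here, an instance of `MaskedBCEWithLogitsLoss` would be called with the same three arguments, `criterion(input, target, weight_mask)`. With an all-ones mask, the sketch above reduces to the ordinary `binary_cross_entropy_with_logits` with mean reduction.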