BCEWithLogitsWithClassWeightLoss

class torch_ecg.models.loss.BCEWithLogitsWithClassWeightLoss(class_weight: torch.Tensor)

Bases: torch.nn.modules.loss.BCEWithLogitsLoss

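Class-weighted binary cross-entropy loss computed on logits (raw model outputs before the sigmoid), in which the loss for each class is scaled by that class's weight.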

Parameters

class_weight (torch.Tensor) – Weight for each class, of shape (1, n_classes).
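Written out, and assuming (this page does not state it) that the weighted element-wise losses are averaged over every element, the loss for logits x and targets y is roughly as follows. Here N is the number of positions covered by the leading dimensions (batch_size, ...), C is n_classes, w_c is class_weight[0, c], and σ is the sigmoid; these symbols are introduced for illustration only.

```latex
\ell_{i,c} = -\Bigl[\, y_{i,c}\,\log \sigma(x_{i,c}) + (1 - y_{i,c})\,\log\bigl(1 - \sigma(x_{i,c})\bigr) \Bigr],
\qquad
\mathcal{L} = \frac{1}{N C} \sum_{i=1}^{N} \sum_{c=1}^{C} w_c \,\ell_{i,c}
```

With every w_c equal to 1 this reduces to the ordinary mean-reduced BCEWithLogitsLoss.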

forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Forward pass.

Parameters
  • input (torch.Tensor) – The predicted logits (values before the sigmoid is applied), of shape (batch_size, ..., n_classes).

  • target (torch.Tensor) – The target tensor, of shape (batch_size, ..., n_classes).

Returns

The class-weighted binary cross entropy loss.

Return type

torch.Tensor
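The following is a minimal sketch of how such a loss can be built on top of torch.nn.BCEWithLogitsLoss. It assumes that the per-element loss is multiplied by class_weight (broadcast along the last axis) and then averaged over all elements. This page does not confirm that behavior, so treat it as an illustration rather than the torch_ecg source. The class name ClassWeightedBCEWithLogits is hypothetical, chosen so that it does not shadow the library class. PyTorch is required.

```python
import torch
from torch import nn


class ClassWeightedBCEWithLogits(nn.BCEWithLogitsLoss):
    """Hypothetical sketch of a class-weighted BCE-with-logits loss.

    Assumption: each element-wise loss is scaled by the weight of its class
    (last axis) and the result is averaged over all elements.
    """

    def __init__(self, class_weight: torch.Tensor) -> None:
        # reduction="none" keeps per-element losses so they can be weighted
        super().__init__(reduction="none")
        # a buffer moves with the module on .to(device) / .cuda()
        self.register_buffer("class_weight", class_weight)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # element-wise loss, same shape as input: (batch_size, ..., n_classes)
        loss = super().forward(input, target)
        # class_weight of shape (1, n_classes) broadcasts over the last axis
        return torch.mean(loss * self.class_weight)


# Usage: a batch of 4 sequences of length 10 with 3 classes
criterion = ClassWeightedBCEWithLogits(torch.tensor([[1.0, 2.0, 0.5]]))
logits = torch.randn(4, 10, 3)  # raw scores, before the sigmoid
target = torch.randint(0, 2, (4, 10, 3)).float()
loss = criterion(logits, target)  # 0-dim (scalar) tensor
```

Setting reduction="none" in the parent is what allows the weights to be applied before averaging. Registering the weights as a buffer keeps them on the same device as the module, which avoids a CPU/GPU mismatch when the weights are multiplied with the loss.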