CutMix

class torch_ecg.augmenters.CutMix(fs: Optional[int] = None, num_mix: int = 1, alpha: numbers.Real = 0.5, beta: Optional[numbers.Real] = None, prob: float = 0.5, inplace: bool = True, **kwargs: Any)

Bases: torch_ecg.augmenters.base.Augmenter

CutMix augmentation.

CutMix is a data augmentation technique originally proposed by Yun et al. [1], with an official implementation in clovaai/CutMix-PyTorch and an unofficial implementation in ildoonet/cutmix.

The technique was designed for image classification but can also be applied to ECG tasks; for example, it was used successfully in the CPSC2021 challenge on paroxysmal atrial fibrillation (AF) event detection.
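As an illustration of the mechanism only, the following NumPy sketch applies CutMix to a batch of 1-D signals with one-hot class labels. It is not the torch_ecg implementation: the function name `cutmix_1d`, the use of a single cut segment shared by the whole batch, and weighting the labels by the fraction of samples actually pasted (as in the original CutMix recipe) are assumptions made for this example.

```python
import numpy as np


def cutmix_1d(sig, label, alpha=0.5, beta=None, rng=None):
    """Hypothetical sketch of CutMix for batched 1-D signals.

    sig:   array of shape (batch, lead, siglen)
    label: array of shape (batch, num_classes), one-hot or soft labels
    """
    rng = np.random.default_rng() if rng is None else rng
    beta = alpha if beta is None else beta  # beta defaults to alpha
    batch, _, siglen = sig.shape
    # mixing ratio lam ~ Beta(alpha, beta): fraction of each signal that is kept
    lam = rng.beta(alpha, beta)
    cut_len = int(round((1.0 - lam) * siglen))
    # random contiguous segment [start, end) along the time axis
    start = int(rng.integers(0, siglen - cut_len + 1))
    end = start + cut_len
    # pair every sample with a (randomly chosen) partner from the same batch
    perm = rng.permutation(batch)
    mixed_sig = sig.copy()
    mixed_sig[:, :, start:end] = sig[perm, :, start:end]
    # weight labels by the fraction of samples that actually came from the partner
    ratio = cut_len / siglen
    mixed_label = (1.0 - ratio) * label + ratio * label[perm]
    return mixed_sig, mixed_label


rng = np.random.default_rng(0)
sig = rng.standard_normal((8, 12, 5000))
label = np.eye(2)[rng.integers(0, 2, size=8)]  # one-hot labels, shape (8, 2)
new_sig, new_label = cutmix_1d(sig, label, rng=rng)
print(new_sig.shape, new_label.shape)  # (8, 12, 5000) (8, 2)
```

Because the label weight is computed from the rounded cut length rather than from the raw Beta draw, each mixed label matches the proportion of the signal that really came from the partner sample.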

Parameters
  • fs (int, optional) – Sampling frequency, by default None.

  • num_mix (int, default 1) – Number of mixtures.

  • alpha (float, default 0.5) – First shape parameter of the Beta distribution from which the mixing ratio is drawn.

  • beta (float, optional) – Second shape parameter of the Beta distribution; defaults to alpha.

  • prob (float, default 0.5) – Probability of applying this augmenter.

  • inplace (bool, default True) – Whether to perform this augmentation in-place.

  • **kwargs (dict, optional) – Additional keyword arguments.
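A rough sketch of how prob, alpha and beta could interact, using only the standard library. `sample_cutmix_params` is a hypothetical helper, and reading num_mix as the number of independent mixing ratios to draw is an assumption, since the parameter description above does not spell out its exact semantics.

```python
import random


def sample_cutmix_params(alpha=0.5, beta=None, prob=0.5, num_mix=1, rng=None):
    """Hypothetical helper: decide whether to apply CutMix and draw mixing ratios.

    Returns None when the augmentation is skipped (probability 1 - prob),
    otherwise a list of num_mix ratios drawn from Beta(alpha, beta).
    Interpreting num_mix this way is an assumption for illustration.
    """
    rng = rng if rng is not None else random.Random()
    beta = alpha if beta is None else beta  # beta defaults to alpha
    if rng.random() >= prob:
        return None  # augmentation not applied to this call
    return [rng.betavariate(alpha, beta) for _ in range(num_mix)]


rng = random.Random(42)
draws = [sample_cutmix_params(prob=0.7, num_mix=2, rng=rng) for _ in range(1000)]
applied = [d for d in draws if d is not None]
# fraction of calls that applied the augmentation (close to prob)
print(len(applied) / len(draws))
```

With alpha = beta = 0.5 the Beta distribution is U-shaped, so most draws sit near 0 or 1, meaning a small or a large part of each signal is replaced.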

Examples

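import torch
from torch_ecg.augmenters import CutMix

cm = CutMix(prob=0.7)
sig = torch.randn(32, 12, 5000)  # batch of 32 12-lead ECGs, 5000 samples each
lb = torch.randint(0, 2, (32, 5000, 2), dtype=torch.float32)  # segmentation mask with 2 classes
sig, lb = cm(sig, lb)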

References

[1] Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, and Youngjoon Yoo. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 6022–6031. Institute of Electrical and Electronics Engineers (IEEE), October 2019. doi:10.1109/ICCV.2019.00612.

extra_repr_keys() → List[str]

Extra keys for __repr__() and __str__().
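The exact representation torch_ecg produces is not documented here, so the following is only a sketch of the general pattern: a base class builds its string representation from the attribute names returned by extra_repr_keys(). `ExtraReprSketch`, `ToyCutMix` and the chosen key list are hypothetical names for illustration.

```python
from typing import List


class ExtraReprSketch:
    """Hypothetical base class: __repr__ lists the attributes named by extra_repr_keys()."""

    def extra_repr_keys(self) -> List[str]:
        return []

    def __repr__(self) -> str:
        fields = ", ".join(f"{key}={getattr(self, key)!r}" for key in self.extra_repr_keys())
        return f"{type(self).__name__}({fields})"
    # object.__str__ falls back to __repr__, so str() gives the same text


class ToyCutMix(ExtraReprSketch):
    def __init__(self, num_mix=1, alpha=0.5, beta=None, prob=0.5, inplace=True):
        self.num_mix = num_mix
        self.alpha = alpha
        self.beta = alpha if beta is None else beta
        self.prob = prob
        self.inplace = inplace

    def extra_repr_keys(self) -> List[str]:
        # assumed key list, chosen to mirror the constructor arguments
        return ["num_mix", "alpha", "beta", "prob", "inplace"]


print(ToyCutMix(prob=0.7))
# ToyCutMix(num_mix=1, alpha=0.5, beta=0.5, prob=0.7, inplace=True)
```

Keeping the key list in one overridable method lets subclasses extend the representation without reimplementing __repr__.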

forward(sig: torch.Tensor, label: torch.Tensor, *extra_tensors: Sequence[torch.Tensor], **kwargs: Any) → Tuple[torch.Tensor, ...]

Forward method to perform CutMix augmentation.

Parameters
  • sig (torch.Tensor) – Batched ECGs to be augmented, of shape (batch, lead, siglen).

  • label (torch.Tensor) – Class (one-hot) labels, of shape (batch, num_classes); or segmentation masks, of shape (batch, siglen, num_classes).

  • extra_tensors (Sequence[torch.Tensor], optional) – Other tensors to be augmented along with sig and label; none by default.

  • **kwargs (dict, optional) – Additional keyword arguments.
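The two label layouts above place the time axis differently: it is the last axis of sig but axis 1 of a segmentation mask. The following NumPy sketch (a hypothetical helper, not part of torch_ecg) assumes the segment bounds and the pairing permutation have already been drawn, and shows how the same time segment would be pasted into both arrays so the mask stays aligned with the signal.

```python
import numpy as np


def cutmix_with_mask(sig, mask, start, end, perm):
    """Hypothetical helper: paste segment [start, end) from sig[perm] and mask[perm].

    sig:  (batch, lead, siglen)         -> time is the last axis
    mask: (batch, siglen, num_classes)  -> time is axis 1
    """
    new_sig = sig.copy()
    new_mask = mask.copy()
    new_sig[:, :, start:end] = sig[perm, :, start:end]
    new_mask[:, start:end, :] = mask[perm, start:end, :]
    return new_sig, new_mask


batch, lead, siglen, num_classes = 4, 12, 100, 2
# each sample's signal is filled with its batch index, to make the paste visible
sig = np.arange(batch, dtype=float)[:, None, None] * np.ones((batch, lead, siglen))
mask = np.zeros((batch, siglen, num_classes))
mask[:, :, 0] = 1.0          # every sample is class 0 everywhere...
mask[0, :, :] = [0.0, 1.0]   # ...except sample 0, which is class 1 everywhere
perm = np.array([1, 2, 3, 0])  # sample 3 receives its segment from sample 0
new_sig, new_mask = cutmix_with_mask(sig, mask, start=20, end=50, perm=perm)
print(new_sig[3, 0, 19], new_sig[3, 0, 20])  # 3.0 0.0
print(new_mask[3, 20], new_mask[3, 19])      # [0. 1.] [1. 0.]
```

Cutting both arrays with the same start, end and perm keeps every time step's mask entry tied to the sample its signal value came from.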

Returns

Augmented tensors.

Return type

Tuple[torch.Tensor, ...]