CutMix
- class torch_ecg.augmenters.CutMix(fs: Optional[int] = None, num_mix: int = 1, alpha: numbers.Real = 0.5, beta: Optional[numbers.Real] = None, prob: float = 0.5, inplace: bool = True, **kwargs: Any)
Bases: torch_ecg.augmenters.base.Augmenter
CutMix augmentation.
CutMix is a data augmentation technique originally proposed by Yun et al. [1], with an official implementation in clovaai/CutMix-PyTorch and an unofficial implementation in ildoonet/cutmix.
The technique was designed for image classification tasks, but it can also be applied to ECG tasks. It was very successful in the CPSC2021 challenge on paroxysmal AF event detection.
- Parameters
fs (int, optional) – Sampling frequency, by default None.
num_mix (int, default 1) – Number of mixtures.
alpha (float, default 0.5) – Beta distribution parameter.
beta (float, optional) – Beta distribution parameter, by default equal to alpha.
prob (float, default 0.5) – Probability of applying this augmenter.
inplace (bool, default True) – Whether to perform this augmentation in-place.
**kwargs (dict, optional) – Additional keyword arguments.
Examples
import torch
from torch_ecg.augmenters import CutMix

cm = CutMix(prob=0.7)
sig = torch.randn(32, 12, 5000)
lb = torch.randint(0, 2, (32, 5000, 2), dtype=torch.float32)  # 2-class segmentation mask
sig, lb = cm(sig, lb)
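To make the roles of alpha and beta more concrete, the following is a minimal sketch of the general CutMix idea adapted to 1D signals with class labels. It is written in NumPy for brevity and is not the torch_ecg implementation: the helper name `cutmix_1d`, the single shared window per batch, and the rounding of the window length are all assumptions made for illustration.

```python
import numpy as np


def cutmix_1d(sig, label, alpha=0.5, beta=None, rng=None):
    """Hypothetical helper: one CutMix pass over a batch of 1D signals.

    sig:   array of shape (batch, lead, siglen)
    label: array of shape (batch, num_classes), one-hot or soft labels
    """
    rng = np.random.default_rng() if rng is None else rng
    beta = alpha if beta is None else beta  # mirrors "by default equal to alpha"
    batch, _, siglen = sig.shape

    # Sample the fraction of each original signal to keep.
    lam = rng.beta(alpha, beta)
    cut_len = int(round((1.0 - lam) * siglen))
    start = int(rng.integers(0, siglen - cut_len + 1))

    # Pair every sample with a randomly chosen partner from the same batch.
    perm = rng.permutation(batch)

    # Paste the partner's window into a copy of the original signal.
    mixed = sig.copy()
    mixed[:, :, start:start + cut_len] = sig[perm][:, :, start:start + cut_len]

    # Recompute the ratio from the rounded window so labels match the signal.
    lam = 1.0 - cut_len / siglen
    mixed_label = lam * label + (1.0 - lam) * label[perm]
    return mixed, mixed_label, (start, cut_len, perm)


# Usage with assumed shapes: 8 recordings, 12 leads, 500 samples, 2 classes.
rng = np.random.default_rng(0)
sig = rng.standard_normal((8, 12, 500))
label = np.eye(2)[rng.integers(0, 2, size=8)]
mixed, mixed_label, (start, cut_len, perm) = cutmix_1d(sig, label, rng=rng)
```

In this sketch the mixed labels remain a valid distribution over classes, because each row is a convex combination of two one-hot rows.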
References
[1] Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, and Youngjoon Yoo. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features. In Proceedings of the IEEE/CVF International Conference on Computer Vision, 6022–6031. Institute of Electrical and Electronics Engineers (IEEE), October 2019. doi:10.1109/ICCV.2019.00612.
- forward(sig: torch.Tensor, label: torch.Tensor, *extra_tensors: Sequence[torch.Tensor], **kwargs: Any) → Tuple[torch.Tensor, ...]
Forward method to perform CutMix augmentation.
- Parameters
sig (torch.Tensor) – Batched ECGs to be augmented, of shape (batch, lead, siglen).
label (torch.Tensor) – Class (one-hot) labels, of shape (batch, num_classes); or segmentation masks, of shape (batch, siglen, num_classes).
extra_tensors (Sequence[torch.Tensor], optional) – Other tensors to be augmented, by default None.
**kwargs (dict, optional) – Additional keyword arguments.
- Returns
Augmented tensors.
- Return type
Tuple[torch.Tensor, ...]
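When the label is a segmentation mask of shape (batch, siglen, num_classes) rather than a class vector, one plausible approach is to cut the same time window out of the mask that is cut out of the signal, so every sample keeps the annotation of whichever recording it came from. The sketch below illustrates that idea in NumPy. The helper name `cutmix_1d_masks` and the choice of one window per batch are assumptions for illustration; whether forward handles masks in exactly this way is not stated here.

```python
import numpy as np


def cutmix_1d_masks(sig, mask, alpha=0.5, rng=None):
    """Hypothetical helper: CutMix for signals with per-sample masks.

    sig:  array of shape (batch, lead, siglen)
    mask: array of shape (batch, siglen, num_classes)
    """
    rng = np.random.default_rng() if rng is None else rng
    batch, _, siglen = sig.shape

    # Window length is driven by a Beta(alpha, alpha) draw.
    lam = rng.beta(alpha, alpha)
    cut_len = int(round((1.0 - lam) * siglen))
    start = int(rng.integers(0, siglen - cut_len + 1))
    window = slice(start, start + cut_len)
    perm = rng.permutation(batch)

    mixed_sig = sig.copy()
    mixed_mask = mask.copy()
    # The time axis is the last axis of the signal but the middle axis of the mask.
    mixed_sig[:, :, window] = sig[perm][:, :, window]
    mixed_mask[:, window, :] = mask[perm][:, window, :]
    return mixed_sig, mixed_mask, window, perm


# Usage with assumed shapes: 4 recordings, 12 leads, 1000 samples, 2 classes.
rng = np.random.default_rng(42)
sig = rng.standard_normal((4, 12, 1000))
mask = rng.integers(0, 2, size=(4, 1000, 2)).astype(np.float32)
mixed_sig, mixed_mask, window, perm = cutmix_1d_masks(sig, mask, rng=rng)
```

Unlike the class-label case, no label interpolation is needed here: the mask stays binary because each time step is taken wholesale from exactly one source recording.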