Augmenter
- class torch_ecg.augmenters.Augmenter
Bases: torch_ecg.utils.misc.ReprMixin, torch.nn.modules.module.Module, abc.ABC
Base class for augmenters.
An augmenter performs data augmentation on the input ECGs, their labels, and any optional extra tensors.
- abstract forward(sig: torch.Tensor, label: Optional[torch.Tensor] = None, *extra_tensors: Sequence[torch.Tensor], **kwargs: Any) → Tuple[torch.Tensor, ...]
Forward method of the augmenter.
- Parameters
sig (torch.Tensor) – Batched ECGs to be augmented, of shape (batch, lead, siglen).
label (torch.Tensor, optional) – Batched labels of the ECGs.
*extra_tensors (Sequence[torch.Tensor], optional) – Extra tensors to be augmented, e.g. masks used by custom loss functions.
**kwargs (dict, optional) – Additional keyword arguments passed to the augmenter.
- Returns
The augmented ECGs, labels, and optional extra tensors.
- Return type
Sequence[torch.Tensor]
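As a rough illustration of the forward contract, the sketch below implements a hypothetical amplitude-scaling augmenter. The class name RandomAmplitudeScale and its low/high gain range are assumptions, not part of torch_ecg; to keep the sketch self-contained it subclasses torch.nn.Module directly instead of torch_ecg.augmenters.Augmenter. It scales each record by one random gain and passes the label and any extra tensors through unchanged.

```python
from typing import Any, Optional, Sequence, Tuple

import torch
import torch.nn as nn


class RandomAmplitudeScale(nn.Module):
    """Hypothetical augmenter: multiplies each ECG in the batch by a random gain.

    Mirrors the forward signature documented for Augmenter, but subclasses
    nn.Module directly so the sketch runs without torch_ecg installed.
    """

    def __init__(self, low: float = 0.8, high: float = 1.2) -> None:
        super().__init__()
        self.low = low
        self.high = high

    def forward(
        self,
        sig: torch.Tensor,
        label: Optional[torch.Tensor] = None,
        *extra_tensors: Sequence[torch.Tensor],
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, ...]:
        batch = sig.shape[0]
        # One gain per record, shaped (batch, 1, 1) to broadcast over (lead, siglen).
        gain = torch.empty(batch, 1, 1, dtype=sig.dtype, device=sig.device)
        gain.uniform_(self.low, self.high)
        # Amplitude scaling leaves labels and extra tensors unchanged, so pass them through.
        return (sig * gain, label, *extra_tensors)
```

Calling aug(torch.randn(4, 12, 5000), labels) returns a 2-tuple of (scaled signals, labels); any extra tensors passed positionally are appended to the returned tuple in the same order.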
- get_indices(prob: float, pop_size: int, scale_ratio: float = 0.1) → List[int]
Get a list of indices to be selected.
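A random list of indices in the range [0, pop_size-1] is generated, where prob controls the probability of each index being selected.
- Parameters
prob (float) – Probability of each index being selected.
pop_size (int) – Population size; indices are drawn from [0, pop_size-1].
scale_ratio (float, optional) – Defaults to 0.1.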
- Returns
indices – A list of indices.
- Return type
List[int]
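The documentation does not spell out how prob and scale_ratio interact. One plausible reading, shown below as a hypothetical stand-in named get_indices_sketch (with an extra, also hypothetical, rng argument for reproducibility), draws the number of selected indices k from a normal distribution with mean prob * pop_size and standard deviation scale_ratio * pop_size, clips k to [0, pop_size], and then samples k distinct indices. This is an assumption for illustration, not the library's implementation.

```python
import random
from typing import List, Optional


def get_indices_sketch(
    prob: float,
    pop_size: int,
    scale_ratio: float = 0.1,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical stand-in for Augmenter.get_indices.

    Assumption: the number of selected indices k is drawn from a normal
    distribution with mean prob * pop_size and standard deviation
    scale_ratio * pop_size, then clipped to [0, pop_size].
    """
    rng = rng or random.Random()
    k = rng.gauss(prob * pop_size, scale_ratio * pop_size)
    # Clip to a valid sample size and round to an integer count.
    k = int(round(min(max(k, 0), pop_size)))
    # Sample k distinct indices from range(pop_size) without replacement.
    return sorted(rng.sample(range(pop_size), k))
```

Under this reading, setting scale_ratio=0 makes the count deterministic: prob=1.0 selects every index and prob=0.0 selects none.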
TODO
Add a parameter min_dist so that any two selected indices are at least min_dist apart.
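One way the min_dist constraint from the TODO could be enforced is a greedy pass over shuffled candidates. The function select_spaced_indices and its k argument are hypothetical and not part of torch_ecg. A greedy pass can return fewer than k indices when the spacing cannot be satisfied, so a real implementation would still need to decide how to handle that case.

```python
import random
from typing import List, Optional


def select_spaced_indices(
    k: int,
    pop_size: int,
    min_dist: int,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical sketch of the min_dist TODO.

    Draws candidates in random order and keeps one only if it is at least
    min_dist away from every index kept so far. May return fewer than k
    indices if the spacing constraint cannot be met.
    """
    rng = rng or random.Random()
    candidates = list(range(pop_size))
    rng.shuffle(candidates)
    chosen: List[int] = []
    for c in candidates:
        if len(chosen) >= k:
            break
        # Accept the candidate only if it respects the spacing to all kept indices.
        if all(abs(c - s) >= min_dist for s in chosen):
            chosen.append(c)
    return sorted(chosen)
```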