AugmenterManager
- class torch_ecg.augmenters.AugmenterManager(*augs: Optional[Tuple[torch_ecg.augmenters.base.Augmenter, ...]], random: bool = False)
Bases: torch.nn.modules.module.Module
The Module to manage the augmenters.
- Parameters
Examples
import torch
from torch_ecg.cfg import CFG
from torch_ecg.augmenters import AugmenterManager

config = CFG(
    random=False,
    fs=500,
    baseline_wander={},
    label_smooth={},
    mixup={},
    random_flip={},
    random_masking={},
    random_renormalize={},
    stretch_compress={},
)
am = AugmenterManager.from_config(config)
sig = torch.randn(32, 12, 5000)
label = torch.randint(0, 2, (32, 26), dtype=torch.float32)
mask1 = torch.randint(0, 2, (32, 5000, 3), dtype=torch.float32)
mask2 = torch.randint(0, 3, (32, 5000), dtype=torch.long)
sig, label, mask1, mask2 = am(sig, label, mask1, mask2)
- property augmenters: List[torch_ecg.augmenters.base.Augmenter]
The list of augmenters in the manager.
- forward(sig: torch.Tensor, label: Optional[torch.Tensor], *extra_tensors: Sequence[torch.Tensor], **kwargs: Any) → Union[torch.Tensor, Tuple[torch.Tensor]]
Forward the input ECGs through the augmenters.
- Parameters
sig (torch.Tensor) – Batched ECGs to be augmented, of shape (batch, lead, siglen).
label (torch.Tensor, optional) – Batched labels of the ECGs.
*extra_tensors (Sequence[torch.Tensor], optional) – Extra tensors to be augmented, e.g. masks for custom loss functions.
**kwargs (dict, optional) – Additional keyword arguments to be passed to the augmenters.
- Returns
The augmented ECGs, labels, and optional extra tensors.
- Return type
Sequence[torch.Tensor]
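The calling convention described above (signal, label, and any extra tensors go in, and the same items come back in the same order) can be illustrated with a dependency-free sketch. This is not torch_ecg's implementation: `ToyManager`, `negate`, and `smooth` are hypothetical names, plain Python lists stand in for tensors, and reading `random=True` as "shuffle the order in which augmenters are applied" is an assumption.

```python
import random
from typing import Any, Callable, List, Optional, Tuple


def negate(sig: List[float], label: Optional[List[float]], *extras: Any, **kwargs: Any) -> Tuple[Any, ...]:
    """Toy augmenter (hypothetical): flip the sign of every sample."""
    return ([-x for x in sig], label, *extras)


def smooth(sig: List[float], label: Optional[List[float]], *extras: Any,
           smoothing: float = 0.1, **kwargs: Any) -> Tuple[Any, ...]:
    """Toy augmenter (hypothetical): label smoothing; a None label passes through."""
    if label is not None:
        n = len(label)
        label = [(1 - smoothing) * y + smoothing / n for y in label]
    return (sig, label, *extras)


class ToyManager:
    """Illustrative stand-in for a manager that chains augmenters."""

    def __init__(self, *augs: Callable[..., Tuple[Any, ...]],
                 random_order: bool = False, seed: Optional[int] = None) -> None:
        self._augs = list(augs)
        self._random_order = random_order
        self._rng = random.Random(seed)

    def __call__(self, sig: Any, label: Any, *extra_tensors: Any, **kwargs: Any) -> Tuple[Any, ...]:
        order = list(range(len(self._augs)))
        if self._random_order:
            # Assumption: "random" means the application order is shuffled per call.
            self._rng.shuffle(order)
        for idx in order:
            # Each augmenter receives and returns the full (sig, label, *extras) tuple.
            sig, label, *extra_tensors = self._augs[idx](sig, label, *extra_tensors, **kwargs)
        return (sig, label, *extra_tensors)


manager = ToyManager(negate, smooth)
out_sig, out_label, out_mask = manager([1.0, -2.0], [1.0, 0.0], [0, 1])
```

Passing the whole tuple through every augmenter lets each one decide which items it modifies, which is why extra tensors such as masks come back in the same positions they were passed in.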
- classmethod from_config(config: dict) → torch_ecg.augmenters.augmenter_manager.AugmenterManager
Create an AugmenterManager from a configuration.
- Parameters
config (dict) – The configuration of the augmenters, preferably an OrderedDict.
- Returns
am – A new instance of AugmenterManager.
- Return type
AugmenterManager
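The recommendation to use an OrderedDict suggests that the order of the configuration keys matters. The sketch below shows only how key order is preserved and read back; it does not call torch_ecg. `augmenter_names` and `GLOBAL_KEYS` are hypothetical, and treating `random` and `fs` as global settings rather than augmenter entries is an assumption based on the example configuration above.

```python
from collections import OrderedDict
from typing import Any, List, Mapping

# Assumption: these keys configure the manager as a whole, not a single augmenter.
GLOBAL_KEYS = {"random", "fs"}


def augmenter_names(config: Mapping[str, Any]) -> List[str]:
    """Hypothetical helper: list augmenter entries in the order they appear."""
    return [key for key in config if key not in GLOBAL_KEYS]


config = OrderedDict(
    [
        ("random", False),
        ("fs", 500),
        ("baseline_wander", {}),
        ("random_flip", {}),
        ("label_smooth", {}),
    ]
)
names = augmenter_names(config)
```

Since Python 3.7 a plain dict also preserves insertion order, so `augmenter_names(dict(config))` yields the same list; an OrderedDict mainly makes the intent explicit.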