RR_LSTM
- class torch_ecg.models.RR_LSTM(classes: Sequence[str], config: Optional[torch_ecg.cfg.CFG] = None, **kwargs: Any)
Bases: torch.nn.modules.module.Module, torch_ecg.utils.utils_nn.CkptMixin, torch_ecg.utils.utils_nn.SizeMixin, torch_ecg.utils.misc.CitationMixin
LSTM model for RR time series classification or sequence labeling.
An LSTM model taking RR time series as input was studied in [Faust et al., 1] for atrial fibrillation detection. It was further improved in [Wen et al., 2] by incorporating an attention mechanism and conditional random fields.
- Parameters
References
- [1] Oliver Faust, Alex Shenfield, Murtadha Kareem, Tan Ru San, Hamido Fujita, and U Rajendra Acharya. Automated Detection of Atrial Fibrillation using Long Short-Term Memory Network with RR Interval Signals. Computers in Biology and Medicine, 102:327–335, November 2018. doi:10.1016/j.compbiomed.2018.07.001.
- [2] Hao Wen, Wenjian Yu, Yuanqing Wu, Shuai Yang, and Xiaolong Liu. A Scalable Hybrid Model for Atrial Fibrillation Detection. Journal of Mechanics in Medicine and Biology, 21(05):2140021, April 2021. doi:10.1142/s0219519421400212.
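RR_LSTM consumes RR interval series rather than raw ECG samples. A minimal sketch of how such a series could be derived from detected R-peak locations, assuming the peaks are given as sample indices at a known sampling frequency `fs` and that the model expects a single input channel. `rpeaks_to_rr` and `to_batch_first` are hypothetical helpers, not part of torch_ecg:

```python
import numpy as np


def rpeaks_to_rr(rpeaks: np.ndarray, fs: float) -> np.ndarray:
    """Hypothetical helper: R-peak sample indices -> RR intervals in seconds."""
    rpeaks = np.asarray(rpeaks, dtype=float)
    # Differences between consecutive R peaks are in samples;
    # dividing by the sampling frequency converts them to seconds.
    return np.diff(rpeaks) / fs


def to_batch_first(rr: np.ndarray) -> np.ndarray:
    """Hypothetical helper: one RR series -> (batch_size=1, n_channels=1, seq_len)."""
    # Assumes a single input channel; adjust if the model config differs.
    return np.asarray(rr, dtype=np.float32)[np.newaxis, np.newaxis, :]


# Assumed example: R peaks detected at these sample indices, fs = 250 Hz
rpeaks = np.array([100, 350, 610, 840, 1100])
rr = rpeaks_to_rr(rpeaks, fs=250.0)  # four RR intervals, each around 1 s
x = to_batch_first(rr)
print(x.shape)  # (1, 1, 4)
```

N R peaks yield N - 1 RR intervals, so the sequence length seen by the model is one less than the number of detected beats.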
- compute_output_shape(seq_len: Optional[int] = None, batch_size: Optional[int] = None) → Sequence[Optional[int]]
Compute the output shape of the model.
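The output shape takes one of the two layouts documented for `forward`. A hypothetical mirror of that logic, assuming the choice depends only on whether the model performs sequence labeling and that unknown dimensions stay `None`. The flag `sequence_labeling` and the function `rr_lstm_output_shape` are illustrative stand-ins, not torch_ecg API; the real model derives its behavior from configuration this page does not show:

```python
from typing import Optional, Sequence


def rr_lstm_output_shape(
    n_classes: int,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
    sequence_labeling: bool = False,
) -> Sequence[Optional[int]]:
    """Hypothetical: the two output layouts described for RR_LSTM.forward."""
    if sequence_labeling:
        # One prediction per RR interval: (batch_size, seq_len, n_classes)
        return (batch_size, seq_len, n_classes)
    # One prediction per RR series: (batch_size, n_classes)
    return (batch_size, n_classes)


print(rr_lstm_output_shape(2, seq_len=100, batch_size=8, sequence_labeling=True))  # (8, 100, 2)
print(rr_lstm_output_shape(2))  # (None, 2)
```

Returning `None` for dimensions that are not supplied matches the `Sequence[Optional[int]]` return annotation in the signature above.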
- forward(input: torch.Tensor) → torch.Tensor
Forward pass of the model.
- Parameters
input (torch.Tensor) – Input RR series tensor of shape (seq_len, batch_size, n_channels), or (batch_size, n_channels, seq_len) if config.batch_first is True.
- Returns
Output tensor of shape (batch_size, seq_len, n_classes) for sequence labeling, or (batch_size, n_classes) for classification.
- Return type
torch.Tensor
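The two input layouts accepted by `forward` differ only in axis order. The sketch below uses NumPy, so it runs without PyTorch, to show how a batch-first array of shape `(batch_size, n_channels, seq_len)` maps onto `(seq_len, batch_size, n_channels)`. That `RR_LSTM` performs exactly this reordering internally is an assumption, not something this page states:

```python
import numpy as np

batch_size, n_channels, seq_len = 4, 2, 10
# Batch-first layout: (batch_size, n_channels, seq_len)
x_batch_first = np.arange(batch_size * n_channels * seq_len, dtype=np.float32).reshape(
    batch_size, n_channels, seq_len
)

# Reorder axes (0, 1, 2) -> (2, 0, 1), the same reordering as
# torch.Tensor.permute(2, 0, 1), giving (seq_len, batch_size, n_channels).
x_seq_first = np.transpose(x_batch_first, (2, 0, 1))
print(x_seq_first.shape)  # (10, 4, 2)

# The value at time step t, sample b, channel c is unchanged by the reordering.
assert x_seq_first[7, 3, 1] == x_batch_first[3, 1, 7]
```

Only the axis order changes; no values are copied or modified in meaning, which is why either layout can feed the same LSTM.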
- classmethod from_v1(v1_ckpt: str, device: Optional[torch.device] = None) → torch_ecg.models.rr_lstm.RR_LSTM
Restore an instance of the model from a v1 checkpoint.
- inference(input: torch.Tensor, bin_pred_thr: float = 0.5) → torch_ecg.components.outputs.BaseOutput
Inference method for the model.
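This page does not describe what `bin_pred_thr` controls. Judging only by its name and its default of 0.5, it is presumably a decision threshold for binary predictions; the sketch below shows that common pattern, assuming independent per-class sigmoid outputs. `threshold_predictions` is hypothetical and is not torch_ecg's implementation of `inference`:

```python
import numpy as np


def threshold_predictions(logits: np.ndarray, bin_pred_thr: float = 0.5) -> np.ndarray:
    """Hypothetical: per-class logits -> sigmoid probabilities -> 0/1 labels."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    # A class is marked positive when its probability reaches the threshold.
    return (probs >= bin_pred_thr).astype(int)


# Assumed logits for 2 samples x 2 classes
logits = np.array([[2.0, -1.0], [0.5, -3.0]])
default_preds = threshold_predictions(logits)                   # threshold 0.5
strict_preds = threshold_predictions(logits, bin_pred_thr=0.7)  # threshold 0.7
```

Raising the threshold trades sensitivity for specificity: here the second sample's first class (probability about 0.62) is positive at 0.5 but not at 0.7.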