RR_LSTM

class torch_ecg.models.RR_LSTM(classes: Sequence[str], config: Optional[torch_ecg.cfg.CFG] = None, **kwargs: Any)

Bases: torch.nn.modules.module.Module, torch_ecg.utils.utils_nn.CkptMixin, torch_ecg.utils.utils_nn.SizeMixin, torch_ecg.utils.misc.CitationMixin

LSTM model for RR time series classification or sequence labeling.

An LSTM model that takes RR time series as input was studied in [Faust et al., 1] for atrial fibrillation detection. It was further improved in [Wen et al., 2] by incorporating an attention mechanism and conditional random fields.
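For readers unfamiliar with the input, an RR time series is the sequence of intervals between consecutive R peaks of an ECG. The sketch below is a minimal illustration under stated assumptions, not part of the torch_ecg API: the helper `rr_intervals`, the R-peak sample indices and the sampling frequency `fs` are all hypothetical, chosen only to show how such a series can be derived.

```python
from typing import List, Sequence


def rr_intervals(r_peaks: Sequence[int], fs: float) -> List[float]:
    """Hypothetical helper: convert R-peak sample indices to RR intervals in seconds.

    Each RR interval is the time between two consecutive R peaks,
    so n peaks yield n - 1 intervals.
    """
    if fs <= 0:
        raise ValueError("fs must be positive")
    peaks = sorted(r_peaks)  # sort in case the peak indices are unordered
    return [(b - a) / fs for a, b in zip(peaks[:-1], peaks[1:])]


if __name__ == "__main__":
    # Hypothetical R peaks at 0.0 s, 0.8 s, 1.6 s and 2.3 s, sampled at 500 Hz
    print(rr_intervals([0, 400, 800, 1150], fs=500))  # → [0.8, 0.8, 0.7]
```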

Parameters
  • classes (Sequence[str]) – Names of the classes.

  • config (CFG, optional) – Other hyper-parameters of the model (e.g. the LSTM layer settings). Refer to the corresponding config file for details.

References

[1] Oliver Faust, Alex Shenfield, Murtadha Kareem, Tan Ru San, Hamido Fujita, and U Rajendra Acharya. Automated Detection of Atrial Fibrillation using Long Short-Term Memory Network with RR Interval Signals. Computers in Biology and Medicine, 102:327–335, November 2018. doi:10.1016/j.compbiomed.2018.07.001.

[2] Hao Wen, Wenjian Yu, Yuanqing Wu, Shuai Yang, and Xiaolong Liu. A Scalable Hybrid Model for Atrial Fibrillation Detection. Journal of Mechanics in Medicine and Biology, 21(05):2140021, April 2021. doi:10.1142/s0219519421400212.

compute_output_shape(seq_len: Optional[int] = None, batch_size: Optional[int] = None) → Sequence[Optional[int]]

Compute the output shape of the model.

Parameters
  • seq_len (int, optional) – Length of the input series tensor.

  • batch_size (int, optional) – Batch size of the input series tensor.

Returns

output_shape – Output shape of the model.

Return type

sequence
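The output shapes documented for forward suggest the logic sketched below. This is a hypothetical plain-Python mirror, not the library's implementation: the function name `sketch_output_shape` and the flag `return_sequences` are assumptions, the flag standing in for whatever config option selects sequence labeling versus classification, and `None` marks a dimension that is not known in advance.

```python
from typing import Optional, Tuple


def sketch_output_shape(
    n_classes: int,
    return_sequences: bool,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tuple[Optional[int], ...]:
    """Hypothetical mirror of RR_LSTM.compute_output_shape.

    Sequence labeling yields one prediction per time step,
    (batch_size, seq_len, n_classes); classification yields one
    prediction per series, (batch_size, n_classes).
    Unknown dimensions are reported as None.
    """
    if return_sequences:
        return (batch_size, seq_len, n_classes)
    return (batch_size, n_classes)


if __name__ == "__main__":
    print(sketch_output_shape(2, return_sequences=True, seq_len=60, batch_size=32))
    # → (32, 60, 2)
    print(sketch_output_shape(2, return_sequences=False, batch_size=32))
    # → (32, 2)
```

Here `n_classes` would correspond to `len(classes)` passed to the constructor.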

forward(input: torch.Tensor) → torch.Tensor

Forward pass of the model.

Parameters

input (torch.Tensor) – Input RR series tensor of shape (seq_len, batch_size, n_channels), or (batch_size, n_channels, seq_len) if config.batch_first is True.

Returns

Output tensor of shape (batch_size, seq_len, n_classes) for sequence labeling, or (batch_size, n_classes) for classification.

Return type

torch.Tensor
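The two input layouts accepted by forward differ only in the order of their axes. The sketch below illustrates that reordering with NumPy's `ndarray.transpose` on a dummy array; the sizes are hypothetical, and this shows shapes only, not the model's actual preprocessing code.

```python
import numpy as np

# Hypothetical sizes chosen for illustration only
batch_size, n_channels, seq_len = 4, 1, 60

# Layout documented for config.batch_first=True: (batch_size, n_channels, seq_len).
# arange makes every element distinct so the mapping can be checked.
x_batch_first = np.arange(batch_size * n_channels * seq_len, dtype=float).reshape(
    batch_size, n_channels, seq_len
)

# Reorder axes to (seq_len, batch_size, n_channels):
# new axis 0 <- old axis 2, new axis 1 <- old axis 0, new axis 2 <- old axis 1
x_seq_first = x_batch_first.transpose(2, 0, 1)

print(x_seq_first.shape)  # → (60, 4, 1)
```

On a PyTorch tensor the same reordering would be written `x.permute(2, 0, 1)`.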

classmethod from_v1(v1_ckpt: str, device: Optional[torch.device] = None) → torch_ecg.models.rr_lstm.RR_LSTM

Restore an instance of the model from a v1 checkpoint.

Parameters
  • v1_ckpt (str) – Path to the v1 checkpoint file.

  • device (torch.device, optional) – Device on which to load the model.

Returns

model – The model instance restored from the v1 checkpoint.

Return type

RR_LSTM

inference(input: torch.Tensor, bin_pred_thr: float = 0.5) → torch_ecg.components.outputs.BaseOutput

Inference method for the model.