ECG_CRNN¶
- class torch_ecg.models.ECG_CRNN(classes: Sequence[str], n_leads: int, config: Optional[torch_ecg.cfg.CFG] = None, **kwargs: Any)[source]¶
Bases: torch.nn.modules.module.Module, torch_ecg.utils.utils_nn.CkptMixin, torch_ecg.utils.utils_nn.SizeMixin, torch_ecg.utils.misc.CitationMixin
Convolutional (Recurrent) Neural Network for ECG tasks.
This C(R)NN architecture was originally adapted from [Yao et al.1, Yao et al.2] and later modified to be more general and more flexible. Perhaps the best-known model of this family is [Hannun et al.3], a modified 1D ResNet-34. The project website for that model is https://stanfordmlgroup.github.io/projects/ecg2/, and its code is hosted at https://github.com/awni/ecg.
C(R)NN models have long been competitive in various ECG tasks, e.g. CPSC2018 entry 0236 and CPSC2019 entry 0416, and they are also used in the PhysioNet/CinC Challenges. A minimal instantiation sketch follows the references below.
- Parameters
classes (Sequence[str]) – List of the names of the classes.
n_leads (int) – Number of ECG leads (input channels).
config (CFG, optional) – Other hyper-parameters of the model, e.g. CNN/RNN choices and kernel sizes; refer to the corresponding model config for details.
kwargs (dict, optional) – Other keyword arguments.
References
- 1
Qihang Yao, Xiaomao Fan, Yunpeng Cai, Ruxin Wang, Liyan Yin, and Ye Li. Time-Incremental Convolutional Neural Network for Arrhythmia Detection in Varied-Length Electrocardiogram. In 2018 IEEE 16th Intl Conf on Dependable, Autonomic and Secure Computing, 16th Intl Conf on Pervasive Intelligence and Computing, 4th Intl Conf on Big Data Intelligence and Computing and Cyber Science and Technology Congress (DASC/PiCom/DataCom/CyberSciTech), 754–761. Institute of Electrical and Electronics Engineers (IEEE), 8 2018. doi:10.1109/dasc/picom/datacom/cyberscitec.2018.00131.
- 2
Qihang Yao, Ruxin Wang, Xiaomao Fan, Jikui Liu, and Ye Li. Multi-Class Arrhythmia Detection from 12-Lead Varied-Length ECG using Attention-Based Time-Incremental Convolutional Neural Network. Information Fusion, 53:174–182, 1 2020. doi:10.1016/j.inffus.2019.06.024.
- 3
Awni Y Hannun, Pranav Rajpurkar, Masoumeh Haghpanahi, Geoffrey H Tison, Codie Bourn, Mintu P Turakhia, and Andrew Y Ng. Cardiologist-Level Arrhythmia Detection and Classification in Ambulatory Electrocardiograms using a Deep Neural Network. Nature Medicine, 25(1):65–69, 1 2019. doi:10.1038/s41591-018-0268-3.
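Example (a minimal instantiation sketch; the class names and lead count are illustrative, and the import of ECG_CRNN_CONFIG from torch_ecg.model_configs assumes the library's stock C(R)NN config is available under that name):

```python
from copy import deepcopy

from torch_ecg.model_configs import ECG_CRNN_CONFIG  # stock C(R)NN config (assumed name)
from torch_ecg.models import ECG_CRNN

# illustrative task: 3 arbitrary classes, standard 12-lead ECG
classes = ["AF", "PVC", "NSR"]
n_leads = 12

config = deepcopy(ECG_CRNN_CONFIG)  # copy the default config before any tweaking
model = ECG_CRNN(classes=classes, n_leads=n_leads, config=config)
```

The model instance constructed here is reused in the method sketches below.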
- compute_output_shape(seq_len: Optional[int] = None, batch_size: Optional[int] = None) → Sequence[Optional[int]] [source]¶
Compute the output shape of the model.
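A usage sketch, assuming model is the ECG_CRNN instance constructed above; the seq_len and batch_size values are illustrative:

```python
# query the expected output shape for a given input length and batch size;
# entries that cannot be determined (e.g. an unspecified batch size) may be None
out_shape = model.compute_output_shape(seq_len=5000, batch_size=2)
print(out_shape)  # e.g. (2, n_classes) for a plain classification head
```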
- extract_features(input: torch.Tensor) → torch.Tensor [source]¶
Extract the feature map before the dense (linear) classifying layer(s).
- Parameters
input (torch.Tensor) – Input signal tensor, of shape (batch_size, channels, seq_len).
- Returns
features – Feature map tensor, of shape (batch_size, channels, seq_len) or (batch_size, channels).
- Return type
torch.Tensor
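A sketch of extracting the pre-classifier feature map, reusing the model from the instantiation example (shapes are illustrative; whether the sequence dimension is kept depends on the configured pooling):

```python
import torch

x = torch.rand(2, 12, 5000)  # (batch_size, channels, seq_len)
with torch.no_grad():
    feats = model.extract_features(x)
print(feats.shape)  # (batch_size, feat_channels, reduced_seq_len) or (batch_size, feat_channels)
```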
- forward(input: torch.Tensor) → torch.Tensor [source]¶
Forward pass of the model.
- Parameters
input (torch.Tensor) – Input signal tensor, of shape (batch_size, channels, seq_len).
- Returns
pred – Predictions tensor, of shape (batch_size, seq_len, channels) or (batch_size, channels).
- Return type
torch.Tensor
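A minimal forward-pass sketch with the model from the instantiation example (batch size, lead count, and signal length are illustrative):

```python
import torch

x = torch.rand(2, 12, 5000)  # (batch_size, channels, seq_len)
model.eval()
with torch.no_grad():
    pred = model(x)  # calls forward(x)
print(pred.shape)  # (batch_size, channels) for classification, or (batch_size, seq_len, channels) for sequence tagging
```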
- classmethod from_v1(v1_ckpt: str, device: Optional[torch.device] = None) → torch_ecg.models.ecg_crnn.ECG_CRNN [source]¶
Restore an instance of the model from a v1 checkpoint.
- Parameters
v1_ckpt (str) – Path to the v1 checkpoint file.
device (torch.device, optional) – Device to load the restored model on.
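A sketch of restoring from a legacy checkpoint (the checkpoint path is a hypothetical placeholder):

```python
import torch

from torch_ecg.models import ECG_CRNN

# "v1_model.pth.tar" is a placeholder path to a checkpoint saved with the v1 model
model_v2 = ECG_CRNN.from_v1("v1_model.pth.tar", device=torch.device("cpu"))
```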
- inference(input: Union[numpy.ndarray, torch.Tensor], class_names: bool = False, bin_pred_thr: float = 0.5) → torch_ecg.components.outputs.BaseOutput [source]¶
Inference method for the model.
- Parameters
input (numpy.ndarray or torch.Tensor) – Input tensor, of shape (batch_size, channels, seq_len).
class_names (bool, default False) – If True, the returned scalar predictions will be a DataFrame, with class names for each scalar prediction.
bin_pred_thr (float, default 0.5) – Threshold for making binary predictions from scalar predictions.
- Returns
output – The output of the inference method, including the following items:
prob: numpy.ndarray or torch.Tensor, scalar predictions (and binary predictions if class_names is True).
pred: numpy.ndarray or torch.Tensor, array of binary predictions, with values 0 or 1 for each class.
- Return type
BaseOutput
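A sketch of the inference workflow with the model from the instantiation example (the input is random and only illustrative; attribute-style access to the output fields is assumed):

```python
import numpy as np

x = np.random.randn(1, 12, 5000).astype(np.float32)  # (batch_size, channels, seq_len)
output = model.inference(x, class_names=True, bin_pred_thr=0.5)

print(output.prob)  # scalar predictions (a DataFrame with class names, since class_names=True)
print(output.pred)  # binary predictions, 0 or 1 per class
```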