ECG_CRNN

class torch_ecg.models.ECG_CRNN(classes: Sequence[str], n_leads: int, config: Optional[torch_ecg.cfg.CFG] = None, **kwargs: Any)

Bases: torch.nn.modules.module.Module, torch_ecg.utils.utils_nn.CkptMixin, torch_ecg.utils.utils_nn.SizeMixin, torch_ecg.utils.misc.CitationMixin

Convolutional (Recurrent) Neural Network for ECG tasks.

This C(R)NN architecture was first adapted from Yao et al. [1, 2], and then modified to be more general and more flexible. Perhaps the best-known model of this family is that of Hannun et al. [3], a modified 1D-ResNet34. The website of that model is https://stanfordmlgroup.github.io/projects/ecg2/, and its code is hosted at https://github.com/awni/ecg.

C(R)NN models have long been competitive in various ECG tasks, for example entry 0236 of CPSC2018 and entry 0416 of CPSC2019. Such models have also been used in the PhysioNet/CinC Challenges.

Parameters
  • classes (Sequence[str]) – Names of the classes.

  • n_leads (int) – Number of leads (number of input channels).

  • config (dict, optional) – Other hyper-parameters, including kernel sizes, etc. Refer to the corresponding config files.
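The config argument is a (possibly nested) dict of hyper-parameters. When only a few of them need to change, one option is to deep-copy a set of defaults and override just the nested keys of interest, so that the shared defaults are never mutated. The sketch below shows such a recursive merge; merge_config is a hypothetical helper rather than part of torch_ecg, and the keys in DEFAULT_CONFIG (cnn, rnn, global_pool, ...) are placeholders, not the keys used in the actual torch_ecg config files.

```python
from copy import deepcopy
from typing import Any, Dict, Mapping


def merge_config(default: Mapping[str, Any], override: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a deep copy of `default` with `override` merged in recursively."""
    merged = deepcopy(dict(default))
    for key, value in override.items():
        if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
            # Recurse so that sibling keys of a nested section are preserved.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Hypothetical defaults; NOT the real torch_ecg config keys.
DEFAULT_CONFIG = {
    "cnn": {"name": "resnet", "kernel_size": 15, "num_blocks": 4},
    "rnn": {"name": "lstm", "hidden_size": 128},
    "global_pool": "max",
}

user_config = merge_config(DEFAULT_CONFIG, {"cnn": {"kernel_size": 9}, "rnn": {"name": "none"}})
print(user_config["cnn"])  # {'name': 'resnet', 'kernel_size': 9, 'num_blocks': 4}
```

Without the recursion, the override {"cnn": {"kernel_size": 9}} would replace the whole cnn section and silently drop its other keys.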

References

[1] Qihang Yao, Xiaomao Fan, Yunpeng Cai, Ruxin Wang, Liyan Yin, and Ye Li. Time-Incremental Convolutional Neural Network for Arrhythmia Detection in Varied-Length Electrocardiogram. In 2018 IEEE 16th Intl Conf on Dependable, Autonomic and Secure Computing, 16th Intl Conf on Pervasive Intelligence and Computing, 4th Intl Conf on Big Data Intelligence and Computing and Cyber Science and Technology Congress (DASC/PiCom/DataCom/CyberSciTech), 754–761. Institute of Electrical and Electronics Engineers (IEEE), August 2018. doi:10.1109/dasc/picom/datacom/cyberscitec.2018.00131.

[2] Qihang Yao, Ruxin Wang, Xiaomao Fan, Jikui Liu, and Ye Li. Multi-Class Arrhythmia Detection from 12-Lead Varied-Length ECG using Attention-Based Time-Incremental Convolutional Neural Network. Information Fusion, 53:174–182, January 2020. doi:10.1016/j.inffus.2019.06.024.

[3] Awni Y Hannun, Pranav Rajpurkar, Masoumeh Haghpanahi, Geoffrey H Tison, Codie Bourn, Mintu P Turakhia, and Andrew Y Ng. Cardiologist-Level Arrhythmia Detection and Classification in Ambulatory Electrocardiograms using a Deep Neural Network. Nature Medicine, 25(1):65–69, January 2019. doi:10.1038/s41591-018-0268-3.

compute_output_shape(seq_len: Optional[int] = None, batch_size: Optional[int] = None) → Sequence[Optional[int]]

Compute the output shape of the model.

Parameters
  • seq_len (int, optional) – Length of the input signal tensor.

  • batch_size (int, optional) – Batch size of the input signal tensor.

Returns

output_shape – Output shape of the model.

Return type

sequence
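For convolutional stages, the length of the time axis follows the standard PyTorch Conv1d formula, L_out = floor((L_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride) + 1. The sketch below propagates a (batch_size, channels, seq_len) shape through a hypothetical stack of such layers, keeping any dimension that was not given as None, in the spirit of the optional seq_len and batch_size arguments. conv1d_out_len and stack_output_shape are hypothetical helpers; the real compute_output_shape may also have to account for recurrent, attention, or pooling stages.

```python
from typing import Optional, Sequence, Tuple

# Each hypothetical layer: (out_channels, kernel_size, stride, padding).
Layer = Tuple[int, int, int, int]


def conv1d_out_len(seq_len: int, kernel_size: int, stride: int = 1,
                   padding: int = 0, dilation: int = 1) -> int:
    """Output length of a 1D convolution (the PyTorch Conv1d formula)."""
    return (seq_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def stack_output_shape(
    layers: Sequence[Layer],
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Shape (batch_size, channels, seq_len) after a stack of conv layers.

    Unknown dimensions (None) stay None, mirroring the optional
    arguments of compute_output_shape.
    """
    channels: Optional[int] = None
    length = seq_len
    for out_channels, kernel_size, stride, padding in layers:
        channels = out_channels
        if length is not None:
            length = conv1d_out_len(length, kernel_size, stride, padding)
    return (batch_size, channels, length)


# Hypothetical stack: three length-halving (stride 2, padded) convolutions.
LAYERS = [(64, 15, 2, 7), (128, 9, 2, 4), (256, 9, 2, 4)]
print(stack_output_shape(LAYERS, seq_len=5000, batch_size=32))  # (32, 256, 625)
print(stack_output_shape(LAYERS))  # (None, 256, None)
```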

extract_features(input: torch.Tensor) → torch.Tensor

Extract feature map before the dense (linear) classifying layer(s).

Parameters

input (torch.Tensor) – Input signal tensor, of shape (batch_size, channels, seq_len).

Returns

features – Feature map tensor, of shape (batch_size, channels, seq_len) or (batch_size, channels).

Return type

torch.Tensor
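The two possible feature shapes differ only in whether the time axis has been collapsed. A global pooling step, such as taking the maximum or the mean over time for each channel, turns (batch_size, channels, seq_len) into (batch_size, channels). The stdlib-only sketch below illustrates this shape change on nested lists standing in for tensors; global_pool is a hypothetical helper, and which pooling (if any) ECG_CRNN applies is assumed here to depend on its configuration.

```python
from typing import List, Sequence


def global_pool(features: Sequence[Sequence[Sequence[float]]],
                mode: str = "max") -> List[List[float]]:
    """Collapse the time axis: (batch_size, channels, seq_len) -> (batch_size, channels)."""
    if mode not in ("max", "avg"):
        raise ValueError(f"unsupported pooling mode: {mode!r}")
    pooled = []
    for sample in features:  # iterate over the batch
        row = []
        for channel in sample:  # one sequence of values over time per channel
            row.append(max(channel) if mode == "max" else sum(channel) / len(channel))
        pooled.append(row)
    return pooled


# batch_size=2, channels=2, seq_len=3
feats = [
    [[0.1, 0.5, 0.2], [1.0, -1.0, 0.0]],
    [[0.3, 0.3, 0.3], [2.0, 4.0, 6.0]],
]
print(global_pool(feats, "max"))  # [[0.5, 1.0], [0.3, 6.0]]
print(global_pool(feats, "avg"))  # mean over time for each channel
```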

forward(input: torch.Tensor) → torch.Tensor

Forward pass of the model.

Parameters

input (torch.Tensor) – Input signal tensor, of shape (batch_size, channels, seq_len).

Returns

pred – Predictions tensor, of shape (batch_size, seq_len, n_classes) or (batch_size, n_classes).

Return type

torch.Tensor
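Note that the sequence-level output has the time axis before the last axis, i.e. (batch_size, seq_len, ...), whereas the input is (batch_size, channels, seq_len). One way such an output arises is by applying the same dense layer independently at every time step, so that the last axis holds one score per class. The sketch below shows this axis change with plain lists; dense_per_step is a hypothetical helper, and treating the last axis as raw per-class scores with no activation applied is an assumption not stated on this page.

```python
from typing import List, Sequence


def dense_per_step(
    features: Sequence[Sequence[Sequence[float]]],  # (batch_size, channels, seq_len)
    weight: Sequence[Sequence[float]],              # (n_classes, channels)
    bias: Sequence[float],                          # (n_classes,)
) -> List[List[List[float]]]:
    """Apply one linear layer at every time step.

    Returns raw scores of shape (batch_size, seq_len, n_classes): the time
    axis ends up in front of the class axis.
    """
    out = []
    for sample in features:  # sample has shape (channels, seq_len)
        steps = []
        for t in range(len(sample[0])):
            x_t = [channel[t] for channel in sample]  # feature vector at time t
            scores = [sum(w * x for w, x in zip(w_row, x_t)) + b
                      for w_row, b in zip(weight, bias)]
            steps.append(scores)
        out.append(steps)
    return out


# batch_size=1, channels=2, seq_len=3; n_classes=2
features = [[[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]]
weight = [[1.0, 0.0], [1.0, 1.0]]
bias = [0.0, 0.5]
scores = dense_per_step(features, weight, bias)
print(scores)  # [[[1.0, 1.5], [0.0, 1.5], [2.0, 3.5]]]
```

In PyTorch the same effect is usually obtained by permuting the feature map to (batch_size, seq_len, channels) and passing it to nn.Linear, which operates on the last dimension.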

classmethod from_v1(v1_ckpt: str, device: Optional[torch.device] = None) → torch_ecg.models.ecg_crnn.ECG_CRNN

Restore an instance of the model from a v1 checkpoint.

Parameters

  • v1_ckpt (str) – Path to the v1 checkpoint file.

  • device (torch.device, optional) – Device on which to load the model.

Returns

model – The model instance restored from the v1 checkpoint.

Return type

ECG_CRNN

inference(input: Union[numpy.ndarray, torch.Tensor], class_names: bool = False, bin_pred_thr: float = 0.5) → torch_ecg.components.outputs.BaseOutput

Inference method for the model.

Parameters
  • input (numpy.ndarray or torch.Tensor) – Input tensor, of shape (batch_size, channels, seq_len).

  • class_names (bool, default False) – If True, the returned scalar predictions will be a DataFrame, with class names for each scalar prediction.

  • bin_pred_thr (float, default 0.5) – Threshold for making binary predictions from scalar predictions.

Returns

output

The output of the inference method, including the following items:

  • prob: numpy.ndarray or torch.Tensor, the scalar predictions (and also the binary predictions, if class_names is True).

  • pred: numpy.ndarray or torch.Tensor, the binary predictions, i.e. an array with value 0 or 1 for each class.

Return type

BaseOutput
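The bin_pred_thr parameter turns the scalar predictions into the binary pred array. In the sketch below each class score is compared with the threshold independently, so a sample may end up with several positive classes or none. binarize is a hypothetical helper, the inclusive >= comparison is an assumption, and a list of {class_name: 0/1} dicts stands in for a DataFrame labeled with class names. Unlike the method, which takes class_names as a bool, the sketch receives the class names explicitly.

```python
from typing import Dict, List, Optional, Sequence, Union


def binarize(
    prob: Sequence[Sequence[float]],  # (batch_size, n_classes) scalar predictions
    bin_pred_thr: float = 0.5,
    class_names: Optional[Sequence[str]] = None,
) -> Union[List[List[int]], List[Dict[str, int]]]:
    """Threshold each class score independently into 0 or 1.

    The inclusive comparison (>=) is an assumption of this sketch. If class
    names are given, each row becomes a {class_name: 0/1} dict, a stdlib
    stand-in for a DataFrame labeled with class names.
    """
    pred = [[int(p >= bin_pred_thr) for p in row] for row in prob]
    if class_names is None:
        return pred
    return [dict(zip(class_names, row)) for row in pred]


CLASSES = ["AF", "I-AVB", "NSR"]  # hypothetical class names
prob = [[0.9, 0.2, 0.5], [0.1, 0.7, 0.4]]
print(binarize(prob))  # [[1, 0, 1], [0, 1, 0]]
print(binarize(prob, bin_pred_thr=0.8))  # [[1, 0, 0], [0, 0, 0]]
print(binarize(prob, class_names=CLASSES)[0])  # {'AF': 1, 'I-AVB': 0, 'NSR': 1}
```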