ECG_SEQ_LAB_NET

class torch_ecg.models.ECG_SEQ_LAB_NET(classes: Sequence[str], n_leads: int, config: Optional[torch_ecg.cfg.CFG] = None)

Bases: torch_ecg.models.ecg_crnn.ECG_CRNN

State-of-the-art (SOTA) model from the CPSC2019 challenge.

Sequence labeling network for wave delineation, QRS complex detection, and similar tasks. Proposed in [1].

pipeline

(multi-scopic, etc.) cnn -> head ((bidi-lstm ->) "attention" -> seq linear) -> output

Parameters
  • classes (Sequence[str]) – List of the classes for sequence labeling.

  • n_leads (int) – Number of leads (number of input channels).

  • config (dict, optional) – Other hyper-parameters, including kernel sizes, etc. Refer to the corresponding config file.
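A minimal usage sketch based on the signatures documented on this page. It assumes that `torch` and `torch_ecg` are installed and that the model falls back to a usable default configuration when `config` is omitted. The class label `"QRS"`, the 12-lead input, the signal length of 4000 samples, and the helper name `run_demo` are illustrative choices, not values taken from the library.

```python
def run_demo():
    """Build the model with its default config and run one forward pass."""
    # Imported lazily so the sketch can be read without the packages installed.
    import torch
    from torch_ecg.models import ECG_SEQ_LAB_NET

    # Assumed values: one hypothetical class label and a 12-lead input.
    model = ECG_SEQ_LAB_NET(classes=["QRS"], n_leads=12)
    model.eval()

    # Random stand-in signal of shape (batch_size, channels, seq_len).
    signal = torch.randn(2, 12, 4000)
    with torch.no_grad():
        pred = model(signal)

    # Compare with the shape the model reports for the same input size.
    expected = model.compute_output_shape(seq_len=4000, batch_size=2)
    return tuple(pred.shape), tuple(expected)
```

Calling `run_demo()` returns the actual prediction shape together with the shape reported by `compute_output_shape`, which is a quick way to check that the two agree for a given configuration.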

References

[1] Wenjie Cai and Danqin Hu. QRS Complex Detection using Novel Deep Learning Neural Networks. IEEE Access, 8:97082–97089, 2020. doi:10.1109/access.2020.2997473.

compute_output_shape(seq_len: Optional[int] = None, batch_size: Optional[int] = None) -> Sequence[Optional[int]]

Compute the output shape of the model.

Parameters
  • seq_len (int, optional) – Length of the 1d input signal tensor.

  • batch_size (int, optional) – Batch size of the input signal tensor.

Returns

output_shape – The output shape of the model.

Return type

sequence
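The return type `Sequence[Optional[int]]` suggests that dimensions that are not supplied remain unknown in the result. The following sketch illustrates that idea only; `sketch_output_shape` and its `reduction` parameter (an assumed temporal downsampling factor of the backbone) are hypothetical and do not reproduce the library's computation.

```python
from typing import Optional, Tuple


def sketch_output_shape(
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
    n_classes: int = 1,
    reduction: int = 1,
) -> Tuple[Optional[int], Optional[int], int]:
    """Hypothetical stand-in for a shape computation; not the library's code.

    Dimensions that are unknown (None) stay None in the result.
    `reduction` is an assumed temporal downsampling factor.
    """
    if reduction < 1:
        raise ValueError("reduction must be a positive integer")
    # Integer division models a backbone that shortens the time axis.
    out_len = None if seq_len is None else seq_len // reduction
    return (batch_size, out_len, n_classes)


print(sketch_output_shape(seq_len=4000, batch_size=2, n_classes=3, reduction=8))
# prints (2, 500, 3)
print(sketch_output_shape(n_classes=3))
# prints (None, None, 3)
```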

forward(input: torch.Tensor) -> torch.Tensor

Forward pass.

Parameters

input (torch.Tensor) – Input tensor, of shape (batch_size, channels, seq_len).

Returns

pred – Output tensor, of shape (batch_size, seq_len, n_classes).

Return type

torch.Tensor
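For tasks such as QRS complex detection, a per-time-step output of this shape is typically turned into index intervals. The sketch below shows one simple way to do that with plain Python lists, assuming the values have already been mapped to per-step probabilities. The helper `mask_to_intervals`, the 0.5 threshold, and the toy numbers are illustrative assumptions, not part of the library.

```python
from typing import List, Tuple


def mask_to_intervals(
    probs: List[float], threshold: float = 0.5
) -> List[Tuple[int, int]]:
    """Return half-open [start, end) index intervals where probs >= threshold."""
    intervals = []
    start = None
    for idx, p in enumerate(probs):
        if p >= threshold and start is None:
            # A new above-threshold run begins here.
            start = idx
        elif p < threshold and start is not None:
            # The current run ended at the previous index.
            intervals.append((start, idx))
            start = None
    if start is not None:
        # Close a run that extends to the end of the sequence.
        intervals.append((start, len(probs)))
    return intervals


# One record, eight time steps, one class: pred[batch][time][class].
pred = [[[0.1], [0.8], [0.9], [0.2], [0.1], [0.7], [0.6], [0.9]]]
qrs_probs = [step[0] for step in pred[0]]
print(mask_to_intervals(qrs_probs))
# prints [(1, 3), (5, 8)]
```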

classmethod from_v1(v1_ckpt: str, device: Optional[torch.device] = None) -> torch_ecg.models.ecg_seq_lab_net.ECG_SEQ_LAB_NET

Convert the v1 model to the current version.

Parameters
  • v1_ckpt (str) – Path to the v1 checkpoint file.

  • device (torch.device, optional) – Device on which to load the converted model.

Returns

model – The converted model.

Return type

ECG_SEQ_LAB_NET