torch_ecg.utils.metrics_from_confusion_matrix
- torch_ecg.utils.metrics_from_confusion_matrix(labels: Union[numpy.ndarray, torch.Tensor], outputs: Union[numpy.ndarray, torch.Tensor], num_classes: Optional[int] = None, weights: Optional[Union[numpy.ndarray, torch.Tensor]] = None, thr: float = 0.5) → Dict[str, Union[float, numpy.ndarray]]
Compute macro-averaged metrics, as well as metrics for each class, from one-vs-rest confusion matrices.
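This page does not show the function's internals, so the following is only a minimal NumPy sketch of the idea, not torch_ecg's implementation. It assumes: each input layout from the Parameters list is first converted to a binary (n_samples, n_classes) array (one-hot encoding for class indices, `>= thr` for probabilities); each class then gets its own 2x2 one-vs-rest confusion matrix; undefined ratios (zero denominator) are set to 0; and `weights` are normalized to sum to 1 for the weighted macro average. The helper names `_to_binary` and `ovr_metrics` and the metric keys (`sens`, `spec`, `prec`, `f1`, `acc`) are hypothetical; the real function's return keys are not listed here.

```python
import numpy as np


def _to_binary(arr, num_classes, thr=None):
    """Hypothetical helper: map one input layout to a binary (n_samples, n_classes) array."""
    arr = np.asarray(arr)
    if arr.ndim == 1:
        # class indices of shape (n_samples,) -> one-hot
        out = np.zeros((arr.shape[0], num_classes), dtype=bool)
        out[np.arange(arr.shape[0]), arr.astype(int)] = True
        return out
    if thr is None:
        return arr.astype(bool)  # already binary (labels)
    # probabilities are thresholded; 0/1 outputs are unchanged for 0 < thr <= 1
    return arr >= thr


def ovr_metrics(labels, outputs, num_classes=None, weights=None, thr=0.5):
    """Hypothetical sketch: per-class and macro metrics from one-vs-rest counts."""
    labels, outputs = np.asarray(labels), np.asarray(outputs)
    if num_classes is None:
        # infer the class count from whichever input is 2-D
        two_d = [a for a in (labels, outputs) if a.ndim == 2]
        if not two_d:
            raise ValueError("num_classes is required when both inputs are 1-D")
        num_classes = two_d[0].shape[1]
    y = _to_binary(labels, num_classes)
    y_hat = _to_binary(outputs, num_classes, thr)

    # one 2x2 confusion matrix per class, kept as four count vectors of shape (n_classes,)
    tp = np.sum(y & y_hat, axis=0).astype(float)
    fp = np.sum(~y & y_hat, axis=0).astype(float)
    fn = np.sum(y & ~y_hat, axis=0).astype(float)
    tn = np.sum(~y & ~y_hat, axis=0).astype(float)

    def safe_div(num, den):
        # assumed convention: 0 where the denominator is 0
        return np.divide(num, den, out=np.zeros_like(num), where=den > 0)

    per_class = {
        "sens": safe_div(tp, tp + fn),  # sensitivity / recall
        "spec": safe_div(tn, tn + fp),  # specificity
        "prec": safe_div(tp, tp + fp),  # precision / PPV
        "f1": safe_div(2 * tp, 2 * tp + fp + fn),
        "acc": safe_div(tp + tn, tp + tn + fp + fn),
    }
    # macro average: uniform, or weighted with weights normalized to sum to 1 (assumption)
    w = np.ones(num_classes) if weights is None else np.asarray(weights, dtype=float)
    w = w / w.sum()
    metrics = {f"{k}_per_class": v for k, v in per_class.items()}
    metrics.update({f"macro_{k}": float(np.sum(w * v)) for k, v in per_class.items()})
    return metrics


# usage: class-index labels with probability outputs (num_classes inferred from outputs)
labels = np.array([0, 0, 1, 1])
probs = np.array([[0.8, 0.2], [0.7, 0.3], [0.1, 0.9], [0.6, 0.4]])
m = ovr_metrics(labels, probs)
# per-class sensitivity is 1.0 for class 0 and 0.5 for class 1, so macro_sens is 0.75
```

Macro metrics here are plain (or weighted) means of the per-class values, so a rare class counts as much as a frequent one unless `weights` says otherwise.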
- Parameters
  - labels (numpy.ndarray or torch.Tensor) – Binary labels, of shape (n_samples, n_classes), or indices of each label class, of shape (n_samples,).
  - outputs (numpy.ndarray or torch.Tensor) – Probability outputs, of shape (n_samples, n_classes), or binary outputs, of shape (n_samples, n_classes), or indices of each predicted class, of shape (n_samples,).
  - num_classes (int, optional) – Number of classes. If labels and outputs are both of shape (n_samples,), then num_classes must be specified.
  - weights (numpy.ndarray or torch.Tensor, optional) – Weights for each class, of shape (n_classes,), used to compute macro metrics.
  - thr (float, default: 0.5) – Threshold for binary classification, valid only if outputs is of shape (n_samples, n_classes).
- Returns
  metrics – Metrics computed from the one-vs-rest confusion matrix.
- Return type
  Dict[str, Union[float, numpy.ndarray]]
Examples
>>> from torch_ecg.cfg import DEFAULTS
>>> from torch_ecg.utils import metrics_from_confusion_matrix
>>> # binary labels (100 samples, 10 classes, multi-label)
>>> labels = DEFAULTS.RNG_randint(0, 1, (100, 10))
>>> # probability outputs (100 samples, 10 classes, multi-label)
>>> outputs = DEFAULTS.RNG.random((100, 10))
>>> metrics = metrics_from_confusion_matrix(labels, outputs)
>>> # binarized outputs (100 samples, 10 classes, multi-label)
>>> outputs = DEFAULTS.RNG_randint(0, 1, (100, 10))
>>> # would emit
>>> # RuntimeWarning: `outputs` is probably binary, AUC may be incorrect
>>> metrics = metrics_from_confusion_matrix(labels, outputs)
>>> # categorical outputs (100 samples, 10 classes)
>>> outputs = DEFAULTS.RNG_randint(0, 9, (100,))
>>> # would emit
>>> # RuntimeWarning: `outputs` is probably binary, AUC may be incorrect
>>> metrics = metrics_from_confusion_matrix(labels, outputs)
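The warnings in these examples concern AUC computed from binary outputs. One way to see why such an AUC is suspect (a standard property of ROC analysis, not something taken from torch_ecg's code): with 0/1 scores the ROC curve has a single interior point, so the per-class AUC reduces to (sensitivity + specificity) / 2 and no longer reflects how well the model ranks samples. The sketch below uses a hypothetical `pairwise_auc` helper (the Mann-Whitney formulation, with ties counted as 1/2) on synthetic data to check that identity.

```python
import numpy as np


def pairwise_auc(y_true, scores):
    """Hypothetical helper: one-class ROC AUC via the Mann-Whitney statistic.

    AUC = P(score_pos > score_neg) + 0.5 * P(score_pos == score_neg).
    Uses O(n_pos * n_neg) memory, which is fine for small demos.
    """
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    greater = np.mean(pos[:, None] > neg[None, :])
    ties = np.mean(pos[:, None] == neg[None, :])
    return greater + 0.5 * ties


rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, 200)
y_true[:2] = [0, 1]  # make sure both classes occur
# synthetic scores in [0, 1) that are mildly informative about y_true
probs = 0.3 * y_true + 0.7 * rng.random(200)
binary = (probs >= 0.5).astype(float)  # binarized outputs, as in the second example

tpr = binary[y_true == 1].mean()  # sensitivity
tnr = 1.0 - binary[y_true == 0].mean()  # specificity

auc_probs = pairwise_auc(y_true, probs)
auc_binary = pairwise_auc(y_true, binary)
# auc_binary equals (tpr + tnr) / 2 up to floating-point rounding
print(auc_probs, auc_binary, (tpr + tnr) / 2)
```

Passing probability outputs, as in the first example, keeps the full ranking information that AUC is meant to summarize.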