tessl/pypi-torchmetrics

PyTorch native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains


docs/classification.md

Classification Metrics

Comprehensive classification metrics supporting binary, multiclass, and multilabel scenarios. Each metric is available both as a task-dispatching wrapper class, selected via the task argument, and as dedicated Binary, Multiclass, and Multilabel variants, so the API stays consistent across task types.

Capabilities

Accuracy Metrics

Measures the proportion of correct predictions among all predictions made.

class Accuracy(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryAccuracy(Metric):
    def __init__(
        self,
        threshold: float = 0.5,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassAccuracy(Metric):
    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "micro",
        top_k: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelAccuracy(Metric):
    def __init__(
        self,
        num_labels: int,
        threshold: float = 0.5,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
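
All variants share the stateful update/compute API. A minimal sketch with synthetic data, accumulating MulticlassAccuracy over several batches:

import torch
from torchmetrics.classification import MulticlassAccuracy

# Accumulate over several batches, then compute once over the full stream.
metric = MulticlassAccuracy(num_classes=3, average="macro")
for _ in range(4):
    preds = torch.randn(8, 3).softmax(dim=-1)   # per-class probabilities
    target = torch.randint(0, 3, (8,))
    metric.update(preds, target)                # accumulates internal state
print(metric.compute())                         # macro accuracy over all batches
metric.reset()                                  # clear state for the next epoch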

Area Under ROC Curve (AUROC)

Computes Area Under the Receiver Operating Characteristic Curve, measuring the model's ability to distinguish between classes.

class AUROC(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryAUROC(Metric):
    def __init__(
        self,
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassAUROC(Metric):
    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelAUROC(Metric):
    def __init__(
        self,
        num_labels: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
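
A minimal sketch with random scores showing BinaryAUROC; thresholds=None selects the exact (non-binned) computation, while an integer value switches to a constant-memory binned approximation:

import torch
from torchmetrics.classification import BinaryAUROC

auroc = BinaryAUROC(thresholds=None)    # exact computation from all scores
preds = torch.rand(64)                  # predicted probabilities
target = torch.randint(0, 2, (64,))     # ground-truth labels
print(auroc(preds, target))             # scalar tensor in [0, 1]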

ROC Curves

Computes Receiver Operating Characteristic curves for visualization and analysis.

class ROC(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryROC(Metric):
    def __init__(
        self,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassROC(Metric):
    def __init__(
        self,
        num_classes: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelROC(Metric):
    def __init__(
        self,
        num_labels: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
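
A minimal sketch of MulticlassROC on synthetic data; with thresholds=None each returned value is a list with one tensor per class, since curve lengths vary with the observed scores:

import torch
from torchmetrics.classification import MulticlassROC

roc = MulticlassROC(num_classes=3, thresholds=None)
preds = torch.randn(32, 3).softmax(dim=-1)
target = torch.randint(0, 3, (32,))
fpr, tpr, thresholds = roc(preds, target)
print(len(fpr), fpr[0].shape)           # 3 curves, one per class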

Precision and Recall

Measures the proportion of positive predictions that are correct (precision) and the proportion of actual positives that are correctly identified (recall).

class Precision(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class Recall(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Each precision and recall metric also has Binary, Multiclass, and Multilabel variants with task-specific parameters.
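
A minimal sketch using the task wrappers with average=None to get one score per class (synthetic data):

import torch
from torchmetrics import Precision, Recall

precision = Precision(task="multiclass", num_classes=3, average=None)
recall = Recall(task="multiclass", num_classes=3, average=None)
preds = torch.randn(16, 3).softmax(dim=-1)
target = torch.randint(0, 3, (16,))
print(precision(preds, target))   # shape (3,): one value per class
print(recall(preds, target))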

F-Scores

Weighted harmonic mean of precision and recall, where beta controls the weight given to recall; the F1 score (beta=1) is the most commonly used variant.

class F1Score(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class FBetaScore(Metric):
    def __init__(
        self,
        task: str,
        beta: float = 1.0,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
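
A minimal sketch: F2 (beta=2) weights recall more heavily than precision, which suits screening tasks where misses are costly:

import torch
from torchmetrics import FBetaScore

f2 = FBetaScore(task="binary", beta=2.0)    # recall-weighted F-score
preds = torch.tensor([0.2, 0.7, 0.6, 0.1])
target = torch.tensor([0, 1, 1, 1])
print(f2(preds, target))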

Average Precision

Computes the average precision score, which summarizes a precision-recall curve as the mean of the precisions achieved at each threshold, weighted by the increase in recall from the previous threshold.

class AveragePrecision(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
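
A minimal sketch with synthetic scores, macro-averaging average precision over three classes:

import torch
from torchmetrics import AveragePrecision

ap = AveragePrecision(task="multiclass", num_classes=3, average="macro")
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(0, 3, (20,))
print(ap(preds, target))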

Confusion Matrix

Computes the confusion matrix, giving a detailed per-class breakdown of true/false positives and negatives.

class ConfusionMatrix(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        normalize: Optional[str] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
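
A minimal sketch with synthetic labels; normalize="true" normalizes each row over the true labels of that class:

import torch
from torchmetrics import ConfusionMatrix

cm = ConfusionMatrix(task="multiclass", num_classes=3, normalize="true")
preds = torch.randint(0, 3, (50,))   # hard class predictions are accepted too
target = torch.randint(0, 3, (50,))
print(cm(preds, target))             # shape (3, 3), rows normalized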

Statistical Scores

Computes true positives, false positives, true negatives, false negatives, and support statistics.

class StatScores(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
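
A minimal sketch; for the binary task the result is a tensor of the form [tp, fp, tn, fn, support]:

import torch
from torchmetrics import StatScores

stats = StatScores(task="binary")
preds = torch.tensor([0.8, 0.3, 0.6, 0.2])   # thresholded at 0.5 -> [1, 0, 1, 0]
target = torch.tensor([1, 0, 0, 0])
print(stats(preds, target))                   # tensor([1, 1, 2, 0, 1])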

Threshold-Based Metrics

Metrics that evaluate performance at fixed operating points: each returns the best achievable value of one quantity subject to a constraint on the other, together with the threshold at which it is reached.

class PrecisionAtFixedRecall(Metric):
    def __init__(
        self,
        task: str,
        min_recall: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class RecallAtFixedPrecision(Metric):
    def __init__(
        self,
        task: str,
        min_precision: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class SensitivityAtSpecificity(Metric):
    def __init__(
        self,
        task: str,
        min_specificity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class SpecificityAtSensitivity(Metric):
    def __init__(
        self,
        task: str,
        min_sensitivity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Advanced Classification Metrics

Specialized metrics for specific classification scenarios.

class CohenKappa(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        weights: Optional[str] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MatthewsCorrCoef(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class JaccardIndex(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class HammingDistance(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class ExactMatch(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
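
A minimal sketch contrasting two chance-corrected agreement measures on the same synthetic predictions:

import torch
from torchmetrics import CohenKappa, MatthewsCorrCoef

kappa = CohenKappa(task="multiclass", num_classes=3)
mcc = MatthewsCorrCoef(task="multiclass", num_classes=3)
preds = torch.randint(0, 3, (40,))
target = torch.randint(0, 3, (40,))
print(kappa(preds, target))   # 1.0 = perfect agreement, ~0.0 = chance level
print(mcc(preds, target))     # correlation in [-1, 1]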

Calibration and Ranking Metrics

Metrics for evaluating model calibration and ranking quality.

class CalibrationError(Metric):
    def __init__(
        self,
        task: str,
        n_bins: int = 15,
        norm: str = "l1",
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelRankingAveragePrecision(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelRankingLoss(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelCoverageError(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...
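
A minimal sketch with random scores: expected calibration error via the task wrapper, plus the multilabel label-ranking average precision:

import torch
from torchmetrics import CalibrationError
from torchmetrics.classification import MultilabelRankingAveragePrecision

ece = CalibrationError(task="binary", n_bins=15, norm="l1")
preds = torch.rand(100)                      # predicted probabilities
target = torch.randint(0, 2, (100,))
print(ece(preds, target))                    # 0.0 = perfectly calibrated

lrap = MultilabelRankingAveragePrecision(num_labels=4)
preds = torch.rand(100, 4)
target = torch.randint(0, 2, (100, 4))
print(lrap(preds, target))                   # 1.0 = ideal label ranking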

Usage Examples

Basic Classification

import torch
from torchmetrics import Accuracy, F1Score, ConfusionMatrix

# Binary classification
binary_acc = Accuracy(task="binary")
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 0])
print(binary_acc(preds, target))

# Multiclass classification
multiclass_f1 = F1Score(task="multiclass", num_classes=3, average="macro")
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))
print(multiclass_f1(preds, target))

# Multilabel classification
multilabel_cm = ConfusionMatrix(task="multilabel", num_labels=3)
preds = torch.randn(10, 3).sigmoid()
target = torch.randint(0, 2, (10, 3))
print(multilabel_cm(preds, target))

Threshold-based Metrics

from torchmetrics import PrecisionAtFixedRecall, ROC

# Find precision at 90% recall
precision_at_recall = PrecisionAtFixedRecall(task="binary", min_recall=0.9)
preds = torch.randn(100).sigmoid()
target = torch.randint(0, 2, (100,))
precision_value, threshold = precision_at_recall(preds, target)
print(f"Precision: {precision_value:.3f} at threshold: {threshold:.3f}")

# Compute ROC curve
roc = ROC(task="binary")
fpr, tpr, thresholds = roc(preds, target)

Types

TaskType = Literal["binary", "multiclass", "multilabel"]
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
MDMCAverageType = Literal["global", "samplewise"]
ThresholdType = Optional[Union[int, List[float], Tensor]]

Install with Tessl CLI

npx tessl i tessl/pypi-torchmetrics
