PyTorch native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains
npx @tessl/cli install tessl/pypi-torchmetrics@1.8.0
A comprehensive metrics library for PyTorch and PyTorch Lightning, providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other machine learning domains. TorchMetrics offers distributed and scalable metric computation with consistent APIs, automatic device handling, and seamless integration with PyTorch workflows.
pip install torchmetrics
import torchmetrics
For functional API:
import torchmetrics.functional as F
For specific metrics:
from torchmetrics import Accuracy, AUROC, F1Score
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
from torchmetrics.regression import MeanSquaredError, R2Score
import torch
import torchmetrics

# Initialize metrics
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
f1 = torchmetrics.F1Score(task="multiclass", num_classes=3)

# Create sample predictions and targets
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))

# Compute metrics
acc_score = accuracy(preds, target)
f1_score = f1(preds, target)
print(f"Accuracy: {acc_score:.4f}")
print(f"F1 Score: {f1_score:.4f}")

# Using functional API
# Both functions are aliased on import so they do not rebind the `accuracy`
# metric object or the `f1_score` result created above.
from torchmetrics.functional import accuracy as accuracy_func, f1_score as f1_func

acc_functional = accuracy_func(preds, target, task="multiclass", num_classes=3)
f1_functional = f1_func(preds, target, task="multiclass", num_classes=3)
print(f"Functional Accuracy: {acc_functional:.4f}")
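print(f"Functional F1: {f1_functional:.4f}")
TorchMetrics follows a dual-interface design pattern: stateful class-based metrics (for example `torchmetrics.Accuracy`) and stateless functional counterparts in `torchmetrics.functional`.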
All metrics inherit from the base Metric class, ensuring consistent behavior, automatic device handling, state management, and distributed computation support across the entire library.
Comprehensive classification metrics supporting binary, multiclass, and multilabel scenarios. Includes accuracy, precision, recall, F-scores, ROC/AUC, confusion matrices, and threshold-based metrics.
class Accuracy(Metric):
def __init__(self, task: str, num_classes: int = None, **kwargs): ...
class AUROC(Metric):
def __init__(self, task: str, num_classes: int = None, **kwargs): ...
class F1Score(Metric):
def __init__(self, task: str, num_classes: int = None, **kwargs): ...
class Precision(Metric):
def __init__(self, task: str, num_classes: int = None, **kwargs): ...
class Recall(Metric):
def __init__(self, task: str, num_classes: int = None, **kwargs): ...
class ConfusionMatrix(Metric):
def __init__(self, task: str, num_classes: int, **kwargs): ...Metrics for regression tasks including error measurements, correlation coefficients, and explained variance measures for continuous target prediction evaluation.
class MeanSquaredError(Metric):
def __init__(self, **kwargs): ...
class MeanAbsoluteError(Metric):
def __init__(self, **kwargs): ...
class R2Score(Metric):
def __init__(self, num_outputs: int = 1, **kwargs): ...
class PearsonCorrCoef(Metric):
def __init__(self, num_outputs: int = 1, **kwargs): ...Specialized metrics for audio processing and speech evaluation including signal-to-noise ratios, perceptual quality measures, and separation metrics.
class ScaleInvariantSignalDistortionRatio(Metric):
def __init__(self, **kwargs): ...
class PermutationInvariantTraining(Metric):
def __init__(self, metric, mode: str = "speaker-wise", **kwargs): ...
class PerceptualEvaluationSpeechQuality(Metric):
def __init__(self, fs: int, mode: str = "wb", **kwargs): ...Image quality assessment metrics including structural similarity, peak signal-to-noise ratio, and perceptual quality measures for computer vision applications.
class StructuralSimilarityIndexMeasure(Metric):
def __init__(self, **kwargs): ...
class PeakSignalNoiseRatio(Metric):
def __init__(self, **kwargs): ...
class FrechetInceptionDistance(Metric):
def __init__(self, feature: int = 2048, **kwargs): ...Natural language processing metrics for translation, summarization, and text generation evaluation including BLEU, ROUGE, and semantic similarity measures.
class BLEUScore(Metric):
def __init__(self, n_gram: int = 4, **kwargs): ...
class ROUGEScore(Metric):
def __init__(self, rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL"), **kwargs): ...
class BERTScore(Metric):
def __init__(self, model_name_or_path: str = "distilbert-base-uncased", **kwargs): ...Object detection and instance segmentation metrics for evaluating bounding box predictions, IoU calculations, and mean average precision.
class MeanAveragePrecision(Metric):
def __init__(self, **kwargs): ...
class IntersectionOverUnion(Metric):
def __init__(self, **kwargs): ...
class PanopticQuality(Metric):
def __init__(self, **kwargs): ...Unsupervised learning evaluation metrics including mutual information, Rand indices, and silhouette analysis for cluster quality assessment.
class AdjustedRandScore(Metric):
def __init__(self, **kwargs): ...
class NormalizedMutualInfoScore(Metric):
def __init__(self, average: str = "arithmetic", **kwargs): ...
class CalinskiHarabaszScore(Metric):
def __init__(self, **kwargs): ...Metrics for ranking and retrieval systems including precision at k, mean average precision, and normalized discounted cumulative gain.
class RetrievalMAP(Metric):
def __init__(self, **kwargs): ...
class RetrievalNormalizedDCG(Metric):
def __init__(self, k: int = None, **kwargs): ...
class RetrievalMRR(Metric):
def __init__(self, **kwargs): ...Semantic and instance segmentation evaluation including Dice coefficients, Intersection over Union, and Hausdorff distance for pixel-level predictions.
class DiceScore(Metric):
def __init__(self, **kwargs): ...
class MeanIoU(Metric):
def __init__(self, num_classes: int, **kwargs): ...
class HausdorffDistance(Metric):
def __init__(self, **kwargs): ...Metrics for evaluating multimodal AI systems including video-audio synchronization and cross-modal quality assessment.
class LipVertexError(Metric):
def __init__(self, **kwargs): ...
class CLIPScore(Metric):
def __init__(self, model_name_or_path: str = "openai/clip-vit-base-patch16", **kwargs): ...Statistical measures for analyzing associations and agreements between categorical variables.
class CramersV(Metric):
def __init__(self, num_classes: int, **kwargs): ...
class FleissKappa(Metric):
def __init__(self, mode: str = "counts", **kwargs): ...Metrics for analyzing geometric shapes and spatial configurations.
class ProcrustesDisparity(Metric):
def __init__(self, **kwargs): ...Specialized metrics for video quality assessment and evaluation.
class VideoMultiMethodAssessmentFusion(Metric):
def __init__(self, **kwargs): ...Core utilities for metric aggregation, distributed computation, and advanced metric composition including running statistics and bootstrapping.
class MetricCollection:
def __init__(self, metrics: Union[Dict[str, Metric], List[Metric]], **kwargs): ...
class MeanMetric(Metric):
def __init__(self, **kwargs): ...
class SumMetric(Metric):
def __init__(self, **kwargs): ...
class BootStrapper:
def __init__(self, base_metric: Metric, num_bootstraps: int = 100, **kwargs): ...Direct functional implementations of all metrics for stateless computation, providing immediate results without object instantiation or state management.
# Functional signature stubs: each takes predictions and targets and is
# annotated to return a Tensor. Bodies are elided in this reference listing.
def accuracy(preds: Tensor, target: Tensor, task: str, **kwargs) -> Tensor:
    ...


def f1_score(preds: Tensor, target: Tensor, task: str, **kwargs) -> Tensor:
    ...


# Unlike the two stubs above, this one takes no `task` and no extra keywords.
def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor:
    ...
def structural_similarity_index_measure(preds: Tensor, target: Tensor, **kwargs) -> Tensor: ...
Core imports for TorchMetrics:
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import Tensor
Common type aliases:
# Allowed string values are expressed with Literal. Union["binary", ...] would
# instead treat each string as a forward reference to a type of that name.
TaskType = Literal["binary", "multiclass", "multilabel"]
# None is also accepted, in addition to the string "none".
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
MDMCAverageType = Literal["global", "samplewise"]
ThresholdType = Union[float, List[float], Tensor]
Base metric class:
class Metric:
    """Base class for all metrics.

    Signature-only stub: every method body is elided (``...``), so this
    listing shows the interface rather than an implementation.
    """

    def __init__(self, **kwargs): ...

    # The usage example calls metric objects directly, e.g. accuracy(preds, target).
    def __call__(self, *args, **kwargs) -> Any: ...

    def update(self, *args, **kwargs) -> None: ...

    def compute(self) -> Any: ...

    def reset(self) -> None: ...

    # Accepts a device name or a torch.device; annotated to return a Metric.
    def to(self, device: Union[str, torch.device]) -> "Metric": ...

    def forward(self, *args, **kwargs) -> Any: ...