CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-torchmetrics

PyTorch native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains

Overview
Eval results
Files

functional.mddocs/

Functional API

Direct functional implementations of all metrics for stateless computation, providing immediate results without object instantiation or state management. The functional API offers 350+ functions across all domains.

Overview

The functional API provides stateless versions of all TorchMetrics metrics. These functions compute metrics directly on input tensors without maintaining internal state, making them ideal for one-off computations and integration into custom training loops.

All functional implementations are available under torchmetrics.functional with domain-specific submodules mirroring the class-based organization.

Import Patterns

# General functional import
import torchmetrics.functional as F

# Domain-specific functional imports
import torchmetrics.functional.classification as FC
import torchmetrics.functional.regression as FR
import torchmetrics.functional.audio as FA
import torchmetrics.functional.image as FI
import torchmetrics.functional.text as FT

Capabilities

Classification Functions

Functional implementations of all classification metrics with support for binary, multiclass, and multilabel tasks.

def accuracy(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def f1_score(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def auroc(
    preds: Tensor,
    target: Tensor,
    task: str,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "macro",
    max_fpr: Optional[float] = None,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def precision(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def recall(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

# Task-dispatching confusion matrix.
# `num_classes` is only required for task="multiclass" and `num_labels` only
# for task="multilabel", so both are optional; `threshold` comes first, which
# keeps the parameter order aligned with the other task-dispatching
# classification functions (accuracy, precision, recall, f1_score).
def confusion_matrix(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    normalize: Optional[str] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

Regression Functions

Functional implementations for regression metrics and correlation measures.

def mean_squared_error(
    preds: Tensor,
    target: Tensor,
    squared: bool = True,
    num_outputs: int = 1,
) -> Tensor: ...

def mean_absolute_error(
    preds: Tensor,
    target: Tensor,
    num_outputs: int = 1,
) -> Tensor: ...

# Coefficient of determination.
# The functional form takes no `num_outputs` argument (that belongs to the
# class-based metric); the number of outputs is taken from the input shape.
# `adjusted` is the number of independent regressors used for the adjusted
# R² score (0 = plain R²); `multioutput` selects how per-output scores are
# aggregated.
def r2_score(
    preds: Tensor,
    target: Tensor,
    adjusted: int = 0,
    multioutput: str = "uniform_average",
) -> Tensor: ...

# Pearson correlation coefficient.
# The functional form accepts only `preds` and `target`; there is no
# `num_outputs` keyword (the class-based metric has one). Pass 1-D tensors
# for a single output or (N, d) tensors for d outputs.
def pearson_corrcoef(
    preds: Tensor,
    target: Tensor,
) -> Tensor: ...

# Spearman rank correlation coefficient.
# The functional form accepts only `preds` and `target`; there is no
# `num_outputs` keyword (the class-based metric has one). Pass 1-D tensors
# for a single output or (N, d) tensors for d outputs.
def spearman_corrcoef(
    preds: Tensor,
    target: Tensor,
) -> Tensor: ...

def cosine_similarity(
    preds: Tensor,
    target: Tensor,
    reduction: str = "sum",
) -> Tensor: ...

Audio Functions

Functional audio quality and separation metrics.

# Scale-invariant signal-to-distortion ratio (SI-SDR).
# `zero_mean` defaults to False: the signals are NOT mean-centred unless the
# caller asks for it explicitly.
def scale_invariant_signal_distortion_ratio(
    preds: Tensor,
    target: Tensor,
    zero_mean: bool = False,
) -> Tensor: ...

# Signal-to-distortion ratio (SDR).
# `use_cg_iter`: optional number of conjugate-gradient iterations (None
#   selects the direct solve).
# `filter_length`: length of the allowed distortion filter.
# `zero_mean` defaults to False: inputs are not mean-centred unless requested.
# `load_diag`: optional diagonal loading for numerical stability.
def signal_distortion_ratio(
    preds: Tensor,
    target: Tensor,
    use_cg_iter: Optional[int] = None,
    filter_length: int = 512,
    zero_mean: bool = False,
    load_diag: Optional[float] = None,
) -> Tensor: ...

# Permutation invariant training (PIT).
# `metric_func` is the metric evaluated for every speaker permutation; extra
# keyword arguments are forwarded to it.
# Returns a pair `(best_metric, best_perm)`: the best metric value per sample
# and the speaker permutation that achieved it -- not a single tensor.
def permutation_invariant_training(
    preds: Tensor,
    target: Tensor,
    metric_func: Callable,
    mode: str = "speaker-wise",
    eval_func: str = "max",
    **kwargs: Any,
) -> Tuple[Tensor, Tensor]: ...

# Perceptual Evaluation of Speech Quality (PESQ).
# `fs` is the sampling frequency and `mode` the band ("wb" wide-band or
# "nb" narrow-band); both are required -- `mode` has no default.
# `keep_same_device` moves the result back to the input's device;
# `n_processes` sets the number of worker processes used for the computation.
def perceptual_evaluation_speech_quality(
    preds: Tensor,
    target: Tensor,
    fs: int,
    mode: str,
    keep_same_device: bool = False,
    n_processes: int = 1,
) -> Tensor: ...

Image Functions

Functional image quality assessment metrics.

def peak_signal_noise_ratio(
    preds: Tensor,
    target: Tensor,
    data_range: Optional[float] = None,
    base: float = 10.0,
    reduction: str = "elementwise_mean",
) -> Tensor: ...

def structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Tuple[float, float]] = 1.5,
    kernel_size: Union[int, Tuple[int, int]] = 11,
    reduction: str = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
) -> Tensor: ...

def multiscale_structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Tuple[float, float]] = 1.5,
    kernel_size: Union[int, Tuple[int, int]] = 11,
    reduction: str = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
    normalize: Optional[str] = "relu",
) -> Tensor: ...

def universal_image_quality_index(
    preds: Tensor,
    target: Tensor,
    kernel_size: Union[int, Tuple[int, int]] = 8,
    sigma: Union[float, Tuple[float, float]] = 1.5,
    reduction: str = "elementwise_mean",
) -> Tensor: ...

Text Functions

Functional NLP metrics for text evaluation.

def bleu_score(
    preds: Sequence[str],
    target: Sequence[Sequence[str]],
    n_gram: int = 4,
    smooth: bool = False,
    weights: Optional[Sequence[float]] = None,
) -> Tensor: ...

def rouge_score(
    preds: Union[str, Sequence[str]],
    target: Union[str, Sequence[str], Sequence[Sequence[str]]],
    rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL"),
    use_stemmer: bool = False,
    normalizer: Optional[Callable[[str], str]] = None,
    tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
    accumulate: str = "best",
) -> Dict[str, Tensor]: ...

def word_error_rate(
    preds: Union[str, List[str]],
    target: Union[str, List[str]],
) -> Tensor: ...

def char_error_rate(
    preds: Union[str, List[str]],
    target: Union[str, List[str]],
) -> Tensor: ...

def edit_distance(
    preds: Union[str, List[str]],
    target: Union[str, List[str]],
    substitution_cost: int = 1,
    reduction: Optional[str] = "mean",
) -> Tensor: ...

Clustering Functions

Functional clustering evaluation metrics.

def adjusted_rand_score(
    preds: Tensor,
    target: Tensor,
) -> Tensor: ...

# Normalized mutual information between two cluster assignments.
# The normalisation is selected with the `average_method` keyword
# (not `average`), which defaults to "arithmetic".
def normalized_mutual_info_score(
    preds: Tensor,
    target: Tensor,
    average_method: str = "arithmetic",
) -> Tensor: ...

def calinski_harabasz_score(
    data: Tensor,
    labels: Tensor,
) -> Tensor: ...

def davies_bouldin_score(
    data: Tensor,
    labels: Tensor,
) -> Tensor: ...

Pairwise Functions

Functional pairwise distance and similarity measures.

# Pairwise cosine similarity between the rows of `x` and `y` (or of `x` with
# itself when `y` is omitted).
# `zero_diagonal` defaults to None, meaning "decide from the inputs": the
# diagonal is zeroed only when `y` is not given.
def pairwise_cosine_similarity(
    x: Tensor,
    y: Optional[Tensor] = None,
    reduction: Optional[str] = None,
    zero_diagonal: Optional[bool] = None,
) -> Tensor: ...

# Pairwise Euclidean distance between the rows of `x` and `y` (or of `x` with
# itself when `y` is omitted).
# `zero_diagonal` defaults to None, meaning "decide from the inputs": the
# diagonal is zeroed only when `y` is not given.
def pairwise_euclidean_distance(
    x: Tensor,
    y: Optional[Tensor] = None,
    reduction: Optional[str] = None,
    zero_diagonal: Optional[bool] = None,
) -> Tensor: ...

# Pairwise Manhattan (L1) distance between the rows of `x` and `y` (or of `x`
# with itself when `y` is omitted).
# `zero_diagonal` defaults to None, meaning "decide from the inputs": the
# diagonal is zeroed only when `y` is not given.
def pairwise_manhattan_distance(
    x: Tensor,
    y: Optional[Tensor] = None,
    reduction: Optional[str] = None,
    zero_diagonal: Optional[bool] = None,
) -> Tensor: ...

# Pairwise Minkowski distance between the rows of `x` and `y` (or of `x` with
# itself when `y` is omitted).
# The order of the norm is passed as `exponent` (not `p`); 2 gives the
# Euclidean distance.
# `zero_diagonal` defaults to None, meaning "decide from the inputs": the
# diagonal is zeroed only when `y` is not given.
def pairwise_minkowski_distance(
    x: Tensor,
    y: Optional[Tensor] = None,
    exponent: float = 2,
    reduction: Optional[str] = None,
    zero_diagonal: Optional[bool] = None,
) -> Tensor: ...

Usage Examples

Basic Functional Usage

import torch
import torchmetrics.functional as F

# Binary classification
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 0])

# Compute metrics directly
acc = F.accuracy(preds, target, task="binary")
f1 = F.f1_score(preds, target, task="binary")
auc = F.auroc(preds, target, task="binary")

print(f"Accuracy: {acc:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"AUROC: {auc:.4f}")

Multiclass Classification

import torchmetrics.functional.classification as FC

# Multiclass predictions
preds = torch.randn(100, 5).softmax(dim=-1)
target = torch.randint(0, 5, (100,))

# Compute various metrics
acc = FC.multiclass_accuracy(preds, target, num_classes=5)
precision = FC.multiclass_precision(preds, target, num_classes=5, average="macro")
recall = FC.multiclass_recall(preds, target, num_classes=5, average="macro")
conf_matrix = FC.multiclass_confusion_matrix(preds, target, num_classes=5)

print(f"Accuracy: {acc:.4f}")
print(f"Macro Precision: {precision:.4f}")
print(f"Macro Recall: {recall:.4f}")
print(f"Confusion Matrix Shape: {conf_matrix.shape}")

Regression Metrics

import torchmetrics.functional.regression as FR

# Regression predictions.
# Use 1-D tensors for single-output regression: a (50, 1) input is treated as
# multi-output data by pearson_corrcoef and does not produce the 0-dim result
# that the `:.4f` formatting below requires.
preds = torch.randn(50)
target = torch.randn(50)

# Compute regression metrics (each returns a scalar tensor for 1-D input)
mse = FR.mean_squared_error(preds, target)
mae = FR.mean_absolute_error(preds, target)
r2 = FR.r2_score(preds, target)
pearson = FR.pearson_corrcoef(preds, target)

print(f"MSE: {mse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R²: {r2:.4f}")
print(f"Pearson Correlation: {pearson:.4f}")

Image Quality Assessment

import torchmetrics.functional.image as FI

# Image tensors
preds = torch.rand(4, 3, 256, 256)
target = torch.rand(4, 3, 256, 256)

# Compute image quality metrics
psnr = FI.peak_signal_noise_ratio(preds, target)
ssim = FI.structural_similarity_index_measure(preds, target)
ms_ssim = FI.multiscale_structural_similarity_index_measure(preds, target)

print(f"PSNR: {psnr:.4f}")
print(f"SSIM: {ssim:.4f}")
print(f"MS-SSIM: {ms_ssim:.4f}")

Text Evaluation

import torchmetrics.functional.text as FT

# Text evaluation
preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

# Compute text metrics
bleu = FT.bleu_score(preds, target)
rouge_scores = FT.rouge_score(preds[0], target[0])

print(f"BLEU Score: {bleu:.4f}")
print(f"ROUGE-1 F1: {rouge_scores['rouge1_fmeasure']:.4f}")
print(f"ROUGE-L F1: {rouge_scores['rougeL_fmeasure']:.4f}")

# Error rates
pred_text = ["this is a test"]
target_text = ["this is the test"]
wer = FT.word_error_rate(pred_text, target_text)
cer = FT.char_error_rate(pred_text, target_text)

print(f"Word Error Rate: {wer:.4f}")
print(f"Character Error Rate: {cer:.4f}")

Audio Quality

import torchmetrics.functional.audio as FA

# Audio signals
preds = torch.randn(4, 8000)  # 4 samples, 8000 time steps
target = torch.randn(4, 8000)

# Compute audio metrics
si_sdr = FA.scale_invariant_signal_distortion_ratio(preds, target)
si_snr = FA.scale_invariant_signal_noise_ratio(preds, target)

print(f"SI-SDR: {si_sdr:.4f} dB")
print(f"SI-SNR: {si_snr:.4f} dB")

Pairwise Distances

import torchmetrics.functional.pairwise as FP

# Feature vectors
x = torch.randn(100, 64)  # 100 samples, 64-dim features
y = torch.randn(50, 64)   # 50 samples, 64-dim features

# Compute pairwise similarities and distances
cosine_sim = FP.pairwise_cosine_similarity(x, y)
euclidean_dist = FP.pairwise_euclidean_distance(x, y)
manhattan_dist = FP.pairwise_manhattan_distance(x, y)

print(f"Cosine Similarity Shape: {cosine_sim.shape}")  # (100, 50)
print(f"Euclidean Distance Shape: {euclidean_dist.shape}")  # (100, 50)
print(f"Manhattan Distance Shape: {manhattan_dist.shape}")  # (100, 50)

Clustering Evaluation

import torchmetrics.functional.clustering as FCL

# Clustering results
pred_clusters = torch.randint(0, 3, (100,))
true_clusters = torch.randint(0, 3, (100,))

# Clustering metrics
ari = FCL.adjusted_rand_score(pred_clusters, true_clusters)
nmi = FCL.normalized_mutual_info_score(pred_clusters, true_clusters)

print(f"Adjusted Rand Index: {ari:.4f}")
print(f"Normalized Mutual Info: {nmi:.4f}")

# Internal clustering metrics (require data)
data = torch.randn(100, 10)  # 100 samples, 10 features
ch_score = FCL.calinski_harabasz_score(data, pred_clusters)
db_score = FCL.davies_bouldin_score(data, pred_clusters)

print(f"Calinski-Harabasz Score: {ch_score:.4f}")
print(f"Davies-Bouldin Score: {db_score:.4f}")

Functional vs Class-based API

When to Use Functional API

  • One-off metric computations
  • Custom training loops without Lightning
  • Minimal memory overhead requirements
  • Integration with existing codebases
  • Research experiments requiring flexibility

When to Use Class-based API

  • Accumulating metrics across batches
  • Distributed training scenarios
  • PyTorch Lightning integration
  • Automatic state management needed
  • Complex metric tracking workflows

Types

from typing import Union, Optional, List, Dict, Tuple, Sequence, Callable, Any, Literal
import torch
from torch import Tensor

# Common functional types

# A functional metric returns a tensor, a dict of named tensors, or a tuple
# of tensors.
FunctionalOutput = Union[Tensor, Dict[str, Tensor], Tuple[Tensor, ...]]

# Closed sets of string options are spelled with Literal. Union["a", "b"]
# would instead treat each string as a forward reference to a type named
# `a` / `b`, which is not what these aliases mean.
TaskType = Literal["binary", "multiclass", "multilabel"]
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
ReductionType = Literal["mean", "sum", "none", "elementwise_mean"]

Install with Tessl CLI

npx tessl i tessl/pypi-torchmetrics

docs

audio.md

classification.md

clustering.md

detection.md

functional.md

image.md

index.md

multimodal.md

nominal.md

regression.md

retrieval.md

segmentation.md

shape.md

text.md

utilities.md

video.md

tile.json