PyTorch native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains
Direct functional implementations of all metrics for stateless computation, providing immediate results without object instantiation or state management. The functional API offers 350+ functions across all domains.
The functional API provides stateless versions of all TorchMetrics metrics. These functions compute metrics directly on input tensors without maintaining internal state, making them ideal for one-off computations and integration into custom training loops.
All functional implementations are available under torchmetrics.functional with domain-specific submodules mirroring the class-based organization.
# General functional import
import torchmetrics.functional as F
# Domain-specific functional imports
import torchmetrics.functional.classification as FC
import torchmetrics.functional.regression as FR
import torchmetrics.functional.audio as FA
import torchmetrics.functional.image as FI
import torchmetrics.functional.text as FTFunctional implementations of all classification metrics with support for binary, multiclass, and multilabel tasks.
# ---------------------------------------------------------------------------
# Functional classification metric stubs. Each function routes on `task`
# ("binary" | "multiclass" | "multilabel", per the TaskType alias and the
# usage examples later in this document).
# ---------------------------------------------------------------------------
# Proportion of correctly classified samples.
# `threshold` presumably binarizes probabilistic preds for binary/multilabel
# tasks — confirm against the upstream TorchMetrics signature.
def accuracy(
preds: Tensor,
target: Tensor,
task: str,
threshold: float = 0.5,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "micro",
multidim_average: str = "global",
top_k: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Tensor: ...
# Harmonic mean of precision and recall; identical routing signature to
# accuracy() above.
def f1_score(
preds: Tensor,
target: Tensor,
task: str,
threshold: float = 0.5,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "micro",
multidim_average: str = "global",
top_k: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Tensor: ...
# Area under the ROC curve. Note the default `average` is "macro" here,
# unlike the "micro" default of the thresholded metrics in this section.
# `thresholds=None` presumably selects the exact (non-binned) computation —
# TODO confirm.
def auroc(
preds: Tensor,
target: Tensor,
task: str,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "macro",
max_fpr: Optional[float] = None,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Tensor: ...
# Positive predictive value; same routing signature as accuracy()/f1_score().
def precision(
preds: Tensor,
target: Tensor,
task: str,
threshold: float = 0.5,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "micro",
multidim_average: str = "global",
top_k: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Tensor: ...
# True positive rate; same routing signature as precision().
def recall(
preds: Tensor,
target: Tensor,
task: str,
threshold: float = 0.5,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[str] = "micro",
multidim_average: str = "global",
top_k: Optional[int] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Tensor: ...
# Confusion matrix (square for multiclass; see the example further down that
# inspects `conf_matrix.shape`).
# NOTE(review): `num_classes` is a required positional here while every
# sibling stub defaults it to None — verify this asymmetry against upstream.
def confusion_matrix(
preds: Tensor,
target: Tensor,
task: str,
num_classes: int,
threshold: float = 0.5,
num_labels: Optional[int] = None,
normalize: Optional[str] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Tensor: ...Functional implementations for regression metrics and correlation measures.
# ---------------------------------------------------------------------------
# Functional regression metric stubs. `num_outputs` sizes multi-output
# (column-wise) computation.
# ---------------------------------------------------------------------------
# Mean squared error. By its name, `squared=False` presumably yields RMSE —
# TODO confirm against the upstream docs.
def mean_squared_error(
preds: Tensor,
target: Tensor,
squared: bool = True,
num_outputs: int = 1,
) -> Tensor: ...
# Mean absolute error (L1).
def mean_absolute_error(
preds: Tensor,
target: Tensor,
num_outputs: int = 1,
) -> Tensor: ...
# Coefficient of determination. `adjusted` presumably is the number of
# regressors for adjusted R² (0 = plain R²) — verify upstream.
def r2_score(
preds: Tensor,
target: Tensor,
num_outputs: int = 1,
multioutput: str = "uniform_average",
adjusted: int = 0,
) -> Tensor: ...
# Pearson (linear) correlation coefficient.
def pearson_corrcoef(
preds: Tensor,
target: Tensor,
num_outputs: int = 1,
) -> Tensor: ...
# Spearman (rank) correlation coefficient.
def spearman_corrcoef(
preds: Tensor,
target: Tensor,
num_outputs: int = 1,
) -> Tensor: ...
# Cosine similarity between preds and target.
# NOTE(review): default reduction is "sum" here, unlike the
# "elementwise_mean" used by the image metrics below — verify upstream.
def cosine_similarity(
preds: Tensor,
target: Tensor,
reduction: str = "sum",
) -> Tensor: ...Functional audio quality and separation metrics.
# ---------------------------------------------------------------------------
# Functional audio quality / source-separation metric stubs.
# Values are in dB per the usage example later in this document.
# ---------------------------------------------------------------------------
# SI-SDR between estimated and reference signals; `zero_mean` removes the
# mean before scaling.
def scale_invariant_signal_distortion_ratio(
preds: Tensor,
target: Tensor,
zero_mean: bool = True,
) -> Tensor: ...
# SDR; `use_cg_iter` presumably switches to an iterative conjugate-gradient
# solver and `load_diag` adds diagonal loading for stability — TODO confirm.
def signal_distortion_ratio(
preds: Tensor,
target: Tensor,
use_cg_iter: Optional[int] = None,
filter_length: int = 512,
zero_mean: bool = True,
load_diag: Optional[float] = None,
) -> Tensor: ...
# Wraps `metric` and evaluates it over speaker permutations, picking the one
# optimizing `eval_func` ("max" by default).
# NOTE(review): declared -> Tensor, but upstream PIT returns (metric,
# permutation) — verify the return type.
def permutation_invariant_training(
preds: Tensor,
target: Tensor,
metric: Callable,
mode: str = "speaker-wise",
eval_func: str = "max",
) -> Tensor: ...
# PESQ speech-quality score; `fs` is the sampling frequency (required) and
# `mode` is "wb" (wide-band) by default — presumably "nb" is the other
# option, confirm upstream.
def perceptual_evaluation_speech_quality(
preds: Tensor,
target: Tensor,
fs: int,
mode: str = "wb",
) -> Tensor: ...Functional image quality assessment metrics.
# ---------------------------------------------------------------------------
# Functional image quality metric stubs. `data_range=None` presumably infers
# the value range from the inputs — TODO confirm upstream.
# ---------------------------------------------------------------------------
# PSNR; `base` is the logarithm base (10 -> decibels).
def peak_signal_noise_ratio(
preds: Tensor,
target: Tensor,
data_range: Optional[float] = None,
base: float = 10.0,
reduction: str = "elementwise_mean",
) -> Tensor: ...
# SSIM with a Gaussian (default) or uniform window; k1/k2 are the standard
# SSIM stability constants.
def structural_similarity_index_measure(
preds: Tensor,
target: Tensor,
gaussian_kernel: bool = True,
sigma: Union[float, Tuple[float, float]] = 1.5,
kernel_size: Union[int, Tuple[int, int]] = 11,
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
) -> Tensor: ...
# MS-SSIM; `betas` weights the per-scale contributions (5 scales by default).
def multiscale_structural_similarity_index_measure(
preds: Tensor,
target: Tensor,
gaussian_kernel: bool = True,
sigma: Union[float, Tuple[float, float]] = 1.5,
kernel_size: Union[int, Tuple[int, int]] = 11,
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
k2: float = 0.03,
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
normalize: Optional[str] = "relu",
) -> Tensor: ...
# UQI (universal image quality index); note the 8x8 default window, unlike
# SSIM's 11x11.
def universal_image_quality_index(
preds: Tensor,
target: Tensor,
kernel_size: Union[int, Tuple[int, int]] = 8,
sigma: Union[float, Tuple[float, float]] = 1.5,
reduction: str = "elementwise_mean",
) -> Tensor: ...Functional NLP metrics for text evaluation.
# ---------------------------------------------------------------------------
# Functional text metric stubs. Inputs are strings/lists of strings rather
# than tensors; outputs are still Tensors.
# ---------------------------------------------------------------------------
# BLEU over up-to-`n_gram` n-grams; `target` is a sequence of reference sets
# (one set of references per prediction — see the usage example below).
def bleu_score(
preds: Sequence[str],
target: Sequence[Sequence[str]],
n_gram: int = 4,
smooth: bool = False,
weights: Optional[Sequence[float]] = None,
) -> Tensor: ...
# ROUGE; returns a dict keyed like "rouge1_fmeasure" / "rougeL_fmeasure"
# (see the example below that indexes the result with those keys).
def rouge_score(
preds: Union[str, Sequence[str]],
target: Union[str, Sequence[str], Sequence[Sequence[str]]],
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL"),
use_stemmer: bool = False,
normalizer: Optional[Callable[[str], str]] = None,
tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
accumulate: str = "best",
) -> Dict[str, Tensor]: ...
# Word-level error rate between hypothesis and reference transcripts.
def word_error_rate(
preds: Union[str, List[str]],
target: Union[str, List[str]],
) -> Tensor: ...
# Character-level error rate; same call shape as word_error_rate().
def char_error_rate(
preds: Union[str, List[str]],
target: Union[str, List[str]],
) -> Tensor: ...
# Levenshtein edit distance with configurable substitution cost.
def edit_distance(
preds: Union[str, List[str]],
target: Union[str, List[str]],
substitution_cost: int = 1,
reduction: Optional[str] = "mean",
) -> Tensor: ...Functional clustering evaluation metrics.
# ---------------------------------------------------------------------------
# Functional clustering metric stubs. The first two compare two label
# assignments; the last two are internal indices taking raw `data` plus a
# single labeling (see the clustering example below).
# ---------------------------------------------------------------------------
# Adjusted Rand index between predicted and reference cluster labels.
def adjusted_rand_score(
preds: Tensor,
target: Tensor,
) -> Tensor: ...
# NMI; `average` selects the normalization ("arithmetic" by default).
def normalized_mutual_info_score(
preds: Tensor,
target: Tensor,
average: str = "arithmetic",
) -> Tensor: ...
# Calinski-Harabasz index computed from the data and one labeling.
def calinski_harabasz_score(
data: Tensor,
labels: Tensor,
) -> Tensor: ...
# Davies-Bouldin index (lower is better, per the standard definition —
# confirm upstream convention).
def davies_bouldin_score(
data: Tensor,
labels: Tensor,
) -> Tensor: ...Functional pairwise distance and similarity measures.
# ---------------------------------------------------------------------------
# Functional pairwise similarity/distance stubs. Given x of n rows and y of
# m rows, the result is an (n, m) matrix (see the shape comments in the
# pairwise example below). When y is None the comparison is presumably x
# against itself, which is when `zero_diagonal` applies — TODO confirm.
# ---------------------------------------------------------------------------
# Pairwise cosine similarity matrix.
def pairwise_cosine_similarity(
x: Tensor,
y: Optional[Tensor] = None,
reduction: Optional[str] = None,
zero_diagonal: bool = True,
) -> Tensor: ...
# Pairwise L2 distance matrix.
def pairwise_euclidean_distance(
x: Tensor,
y: Optional[Tensor] = None,
reduction: Optional[str] = None,
zero_diagonal: bool = True,
) -> Tensor: ...
# Pairwise L1 distance matrix.
def pairwise_manhattan_distance(
x: Tensor,
y: Optional[Tensor] = None,
reduction: Optional[str] = None,
zero_diagonal: bool = True,
) -> Tensor: ...
# Pairwise Minkowski distance matrix; `p` is the norm order (p=2 reduces to
# the Euclidean case).
def pairwise_minkowski_distance(
x: Tensor,
y: Optional[Tensor] = None,
p: float = 2.0,
reduction: Optional[str] = None,
zero_diagonal: bool = True,
) -> Tensor: ...import torch
# Example: stateless binary-classification metrics from the top-level
# functional namespace. `preds` holds probabilities in [0, 1]; `target`
# holds 0/1 labels.
import torchmetrics.functional as F
# Binary classification
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 0])
# Compute metrics directly
acc = F.accuracy(preds, target, task="binary")
f1 = F.f1_score(preds, target, task="binary")
auc = F.auroc(preds, target, task="binary")
print(f"Accuracy: {acc:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"AUROC: {auc:.4f}")import torchmetrics.functional.classification as FC
# Example: task-specific multiclass_* variants from the classification
# submodule — these skip the `task` routing argument entirely.
# Multiclass predictions
preds = torch.randn(100, 5).softmax(dim=-1)
target = torch.randint(0, 5, (100,))
# Compute various metrics
acc = FC.multiclass_accuracy(preds, target, num_classes=5)
precision = FC.multiclass_precision(preds, target, num_classes=5, average="macro")
recall = FC.multiclass_recall(preds, target, num_classes=5, average="macro")
conf_matrix = FC.multiclass_confusion_matrix(preds, target, num_classes=5)
print(f"Accuracy: {acc:.4f}")
print(f"Macro Precision: {precision:.4f}")
print(f"Macro Recall: {recall:.4f}")
print(f"Confusion Matrix Shape: {conf_matrix.shape}")import torchmetrics.functional.regression as FR
# Example: regression metrics on (50, 1)-shaped continuous predictions.
# Regression predictions
preds = torch.randn(50, 1)
target = torch.randn(50, 1)
# Compute regression metrics
mse = FR.mean_squared_error(preds, target)
mae = FR.mean_absolute_error(preds, target)
r2 = FR.r2_score(preds, target)
pearson = FR.pearson_corrcoef(preds, target)
print(f"MSE: {mse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R²: {r2:.4f}")
print(f"Pearson Correlation: {pearson:.4f}")import torchmetrics.functional.image as FI
# Example: image quality metrics on (batch, channels, H, W) tensors with
# values in [0, 1] (torch.rand).
# Image tensors
preds = torch.rand(4, 3, 256, 256)
target = torch.rand(4, 3, 256, 256)
# Compute image quality metrics
psnr = FI.peak_signal_noise_ratio(preds, target)
ssim = FI.structural_similarity_index_measure(preds, target)
ms_ssim = FI.multiscale_structural_similarity_index_measure(preds, target)
print(f"PSNR: {psnr:.4f}")
print(f"SSIM: {ssim:.4f}")
print(f"MS-SSIM: {ms_ssim:.4f}")import torchmetrics.functional.text as FT
# Example: text metrics. For BLEU, `target` pairs each prediction with a
# list of acceptable references; rouge_score here is called on a single
# (pred, references) pair and indexed by its f-measure keys.
# Text evaluation
preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]
# Compute text metrics
bleu = FT.bleu_score(preds, target)
rouge_scores = FT.rouge_score(preds[0], target[0])
print(f"BLEU Score: {bleu:.4f}")
print(f"ROUGE-1 F1: {rouge_scores['rouge1_fmeasure']:.4f}")
print(f"ROUGE-L F1: {rouge_scores['rougeL_fmeasure']:.4f}")
# Error rates
pred_text = ["this is a test"]
target_text = ["this is the test"]
wer = FT.word_error_rate(pred_text, target_text)
cer = FT.char_error_rate(pred_text, target_text)
print(f"Word Error Rate: {wer:.4f}")
print(f"Character Error Rate: {cer:.4f}")import torchmetrics.functional.audio as FA
# Example: audio metrics on (batch, time) waveforms; results are in dB.
# NOTE(review): scale_invariant_signal_noise_ratio is used here but is not
# listed among the stubs above — verify it exists in the functional.audio
# namespace.
# Audio signals
preds = torch.randn(4, 8000) # 4 samples, 8000 time steps
target = torch.randn(4, 8000)
# Compute audio metrics
si_sdr = FA.scale_invariant_signal_distortion_ratio(preds, target)
si_snr = FA.scale_invariant_signal_noise_ratio(preds, target)
print(f"SI-SDR: {si_sdr:.4f} dB")
print(f"SI-SNR: {si_snr:.4f} dB")import torchmetrics.functional.pairwise as FP
# Example: pairwise comparisons between two feature matrices; with x of
# 100 rows and y of 50 rows each result is a (100, 50) matrix.
# Feature vectors
x = torch.randn(100, 64) # 100 samples, 64-dim features
y = torch.randn(50, 64) # 50 samples, 64-dim features
# Compute pairwise similarities and distances
cosine_sim = FP.pairwise_cosine_similarity(x, y)
euclidean_dist = FP.pairwise_euclidean_distance(x, y)
manhattan_dist = FP.pairwise_manhattan_distance(x, y)
print(f"Cosine Similarity Shape: {cosine_sim.shape}") # (100, 50)
print(f"Euclidean Distance Shape: {euclidean_dist.shape}") # (100, 50)
print(f"Manhattan Distance Shape: {manhattan_dist.shape}") # (100, 50)import torchmetrics.functional.clustering as FCL
# Example: clustering metrics — external indices (ARI, NMI) compare two
# labelings; internal indices (CH, DB) additionally need the raw data.
# Clustering results
pred_clusters = torch.randint(0, 3, (100,))
true_clusters = torch.randint(0, 3, (100,))
# Clustering metrics
ari = FCL.adjusted_rand_score(pred_clusters, true_clusters)
nmi = FCL.normalized_mutual_info_score(pred_clusters, true_clusters)
print(f"Adjusted Rand Index: {ari:.4f}")
print(f"Normalized Mutual Info: {nmi:.4f}")
# Internal clustering metrics (require data)
data = torch.randn(100, 10) # 100 samples, 10 features
ch_score = FCL.calinski_harabasz_score(data, pred_clusters)
db_score = FCL.davies_bouldin_score(data, pred_clusters)
print(f"Calinski-Harabasz Score: {ch_score:.4f}")
print(f"Davies-Bouldin Score: {db_score:.4f}")from typing import Union, Optional, List, Dict, Tuple, Sequence, Callable, Any
import torch
from torch import Tensor
# Literal is needed for value-constrained aliases: quoted members inside a
# Union are forward references (undefined names here), not literal values,
# so the original Union["binary", ...] form was invalid typing syntax.
from typing import Literal

# Common functional types
# Any single functional metric returns a Tensor, a dict of named Tensors
# (e.g. rouge_score), or a tuple of Tensors.
FunctionalOutput = Union[Tensor, Dict[str, Tensor], Tuple[Tensor, ...]]
# Valid `task` values for the routing classification functions above.
TaskType = Literal["binary", "multiclass", "multilabel"]
# Valid `average` values; None disables averaging, mirroring the
# Optional[str] parameter annotations above.
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
# Valid `reduction` values across the functional API.
ReductionType = Literal["mean", "sum", "none", "elementwise_mean"]
# Install with Tessl CLI
npx tessl i tessl/pypi-torchmetrics