CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-torchmetrics

PyTorch native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains

Overview
Eval results
Files

utilities.mddocs/

Utilities and Aggregation

Core utilities for metric aggregation, distributed computation, and advanced metric composition including running statistics, bootstrapping, and metric collections for comprehensive evaluation workflows.

Capabilities

Metric Collection

Container for organizing and computing multiple metrics simultaneously with automatic synchronization and state management.

class MetricCollection:
    def __init__(
        self,
        metrics: Union[Dict[str, Metric], List[Metric], Tuple[Metric, ...]],
        prefix: Optional[str] = None,
        postfix: Optional[str] = None,
        compute_groups: Union[bool, List[List[str]]] = True,
        **kwargs
    ): ...
    
    def __call__(self, *args, **kwargs) -> Dict[str, Any]: ...
    def update(self, *args, **kwargs) -> None: ...
    def compute(self) -> Dict[str, Any]: ...
    def reset(self) -> None: ...
    def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": ...

Aggregation Metrics

Basic metrics for accumulating and aggregating values across batches and distributed processes.

class MeanMetric(Metric):
    def __init__(
        self,
        nan_strategy: str = "warn",
        **kwargs
    ): ...

class SumMetric(Metric):
    def __init__(
        self,
        nan_strategy: str = "warn",
        **kwargs
    ): ...

class MaxMetric(Metric):
    def __init__(
        self,
        nan_strategy: str = "warn",
        **kwargs
    ): ...

class MinMetric(Metric):
    def __init__(
        self,
        nan_strategy: str = "warn",
        **kwargs
    ): ...

class CatMetric(Metric):
    def __init__(
        self,
        nan_strategy: str = "warn",
        **kwargs
    ): ...

Running Statistics

Metrics that maintain running statistics over streaming data without storing all values.

class RunningMean(Metric):
    def __init__(
        self,
        window: int = 100,
        **kwargs
    ): ...

class RunningSum(Metric):
    def __init__(
        self,
        window: int = 100,
        **kwargs
    ): ...

Metric Wrappers

Advanced wrappers that enhance metric functionality with additional capabilities.

class BootStrapper:
    def __init__(
        self,
        base_metric: Union[Metric, Callable],
        num_bootstraps: int = 100,
        mean: bool = True,
        std: bool = True,
        raw: bool = False,
        quantile: Optional[Union[float, Tensor]] = None,
        sampling_strategy: str = "poisson",
        **kwargs
    ): ...

class ClasswiseWrapper:
    def __init__(
        self,
        metric: Metric,
        labels: Optional[List[str]] = None,
        **kwargs
    ): ...

class MetricTracker:
    def __init__(
        self,
        metric: Metric,
        maximize: bool = True,
        **kwargs
    ): ...

class MinMaxMetric:
    def __init__(
        self,
        base_metric: Metric,
        **kwargs
    ): ...

class MultioutputWrapper:
    def __init__(
        self,
        metric: Metric,
        num_outputs: int,
        **kwargs
    ): ...

class MultitaskWrapper:
    def __init__(
        self,
        task_metrics: Dict[str, Metric],
        **kwargs
    ): ...

Advanced Wrappers

Specialized wrappers for complex metric computation scenarios.

class FeatureShare:
    def __init__(
        self,
        metric: Metric,
        reset_real_features: bool = True,
        **kwargs
    ): ...

class LambdaInputTransformer:
    def __init__(
        self,
        metric: Metric,
        transform_func: Callable,
        transform_labels: bool = True,
        **kwargs
    ): ...

class MetricInputTransformer:
    def __init__(
        self,
        metric: Metric,
        **kwargs
    ): ...

class Running:
    def __init__(
        self,
        base_metric: Metric,
        window_size: int = 100,
        **kwargs
    ): ...

class BinaryTargetTransformer:
    def __init__(
        self,
        metric: Metric,
        target_transform: Callable[[Tensor], Tensor],
        **kwargs
    ): ...

Usage Examples

Basic Metric Collection

import torch
from torchmetrics import MetricCollection, Accuracy, F1Score, Precision, Recall

# Create a collection of classification metrics
metric_collection = MetricCollection({
    'accuracy': Accuracy(task="multiclass", num_classes=3),
    'f1': F1Score(task="multiclass", num_classes=3),
    'precision': Precision(task="multiclass", num_classes=3),
    'recall': Recall(task="multiclass", num_classes=3)
})

# Sample data
preds = torch.randn(32, 3).softmax(dim=-1)
target = torch.randint(0, 3, (32,))

# Compute all metrics at once
results = metric_collection(preds, target)
for metric_name, value in results.items():
    print(f"{metric_name}: {value:.4f}")

Aggregation Metrics

import torch
from torchmetrics import MeanMetric, SumMetric, MaxMetric

# Initialize aggregation metrics
mean_loss = MeanMetric()
total_samples = SumMetric()
max_confidence = MaxMetric()

# Accumulate values across batches
for batch_idx in range(10):
    batch_loss = torch.rand(1) * 2  # Random loss
    batch_size = torch.tensor(32.0)  # Batch size
    batch_max_conf = torch.rand(1)  # Max confidence in batch
    
    mean_loss.update(batch_loss)
    total_samples.update(batch_size)
    max_confidence.update(batch_max_conf)

# Get final aggregated values
print(f"Mean Loss: {mean_loss.compute():.4f}")
print(f"Total Samples: {total_samples.compute():.0f}")
print(f"Max Confidence: {max_confidence.compute():.4f}")

Bootstrapping for Confidence Intervals

import torch
from torchmetrics.wrappers import BootStrapper
from torchmetrics import Accuracy

# Bootstrap accuracy for confidence intervals
base_accuracy = Accuracy(task="binary")
bootstrap_accuracy = BootStrapper(
    base_accuracy, 
    num_bootstraps=1000,
    mean=True,
    std=True,
    quantile=torch.tensor([0.025, 0.975])  # 95% confidence interval
)

# Sample binary classification data
preds = torch.rand(100)
target = torch.randint(0, 2, (100,))

# Compute bootstrapped statistics
bootstrap_results = bootstrap_accuracy(preds, target)
print(f"Mean Accuracy: {bootstrap_results['mean']:.4f}")
print(f"Std Accuracy: {bootstrap_results['std']:.4f}")
print(f"95% Confidence Interval: [{bootstrap_results['quantile'][0]:.4f}, {bootstrap_results['quantile'][1]:.4f}]")

Per-Class Metrics

import torch
from torchmetrics.wrappers import ClasswiseWrapper
from torchmetrics import F1Score

# Compute F1 score per class
class_labels = ['cat', 'dog', 'bird']
base_f1 = F1Score(task="multiclass", num_classes=3, average=None)
classwise_f1 = ClasswiseWrapper(base_f1, labels=class_labels)

# Sample multiclass data
preds = torch.randn(100, 3).softmax(dim=-1)
target = torch.randint(0, 3, (100,))

# Get per-class results
classwise_results = classwise_f1(preds, target)
for class_name, f1_score in classwise_results.items():
    print(f"F1 Score for {class_name}: {f1_score:.4f}")

Metric Tracking

import torch
from torchmetrics.wrappers import MetricTracker
from torchmetrics import Accuracy

# Track best accuracy over time
tracker = MetricTracker(Accuracy(task="binary"), maximize=True)

# Simulate training epochs
accuracies = [0.6, 0.7, 0.65, 0.8, 0.75, 0.85, 0.82]

for epoch, acc in enumerate(accuracies):
    tracker.increment()  # start a new tracked step before updating
    preds = torch.rand(100)
    target = torch.randint(0, 2, (100,))
    
    # Update tracker (automatically keeps best)
    result = tracker(preds, target)

best_value, best_step = tracker.best_metric(return_step=True)
print(f"Best Accuracy: {best_value:.4f}")
print(f"Best Accuracy at Epoch: {best_step}")

Running Statistics

import torch
from torchmetrics import RunningMean

# Running mean with sliding window
running_mean = RunningMean(window=50)

# Simulate streaming data
for i in range(200):
    value = torch.tensor(float(i + torch.randn(1) * 0.1))
    running_mean.update(value)
    
    if i % 50 == 0:
        print(f"Step {i}: Running Mean = {running_mean.compute():.2f}")

Multi-output Wrapper

import torch
from torchmetrics.wrappers import MultioutputWrapper
from torchmetrics import MeanSquaredError

# MSE for multi-output regression
multi_mse = MultioutputWrapper(MeanSquaredError(), num_outputs=3)

# Multi-output predictions and targets
preds = torch.randn(50, 3)  # 50 samples, 3 outputs
target = torch.randn(50, 3)

# Compute MSE for each output
results = multi_mse(preds, target)
for i, mse in enumerate(results):
    print(f"Output {i+1} MSE: {mse:.4f}")

Multi-task Learning

import torch
from torchmetrics.wrappers import MultitaskWrapper
from torchmetrics import Accuracy, MeanSquaredError

# Metrics for multi-task learning
task_metrics = {
    'classification': Accuracy(task="multiclass", num_classes=5),
    'regression': MeanSquaredError()
}
multitask_metric = MultitaskWrapper(task_metrics)

# Sample multi-task predictions
task_preds = {
    'classification': torch.randn(32, 5).softmax(dim=-1),
    'regression': torch.randn(32, 1)
}
task_targets = {
    'classification': torch.randint(0, 5, (32,)),
    'regression': torch.randn(32, 1)
}

# Compute metrics for all tasks
task_results = multitask_metric(task_preds, task_targets)
for task, result in task_results.items():
    print(f"{task}: {result:.4f}")

Input Transformation

import torch
from torchmetrics.wrappers import LambdaInputTransformer
from torchmetrics import Accuracy

# Transform inputs before metric computation
def logits_to_probs(logits):
    return torch.softmax(logits, dim=-1)

# Wrap accuracy with input transformation
transformed_accuracy = LambdaInputTransformer(
    Accuracy(task="multiclass", num_classes=3),
    transform_func=logits_to_probs,
    transform_labels=False  # Don't transform targets
)

# Raw logits input
logits = torch.randn(32, 3)
target = torch.randint(0, 3, (32,))

# Accuracy automatically applies softmax
acc = transformed_accuracy(logits, target)
print(f"Accuracy: {acc:.4f}")

Types

from typing import Dict, List, Literal, Optional, Union, Callable, Any, Tuple
import torch
from torch import Tensor

MetricDict = Dict[str, Metric]
ComputeGroupsType = Union[bool, List[List[str]]]
# Literal (not Union) constrains a parameter to a fixed set of string values
NaNStrategy = Literal["warn", "error", "ignore"]
SamplingStrategy = Literal["poisson", "multinomial"]

Install with Tessl CLI

npx tessl i tessl/pypi-torchmetrics

docs

audio.md

classification.md

clustering.md

detection.md

functional.md

image.md

index.md

multimodal.md

nominal.md

regression.md

retrieval.md

segmentation.md

shape.md

text.md

utilities.md

video.md

tile.json