PyTorch-native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains.
Core utilities for metric aggregation, distributed computation, and advanced metric composition including running statistics, bootstrapping, and metric collections for comprehensive evaluation workflows.
Container for organizing and computing multiple metrics simultaneously with automatic synchronization and state management.
class MetricCollection:
def __init__(
self,
metrics: Union[Dict[str, Metric], List[Metric], Tuple[Metric, ...]],
prefix: Optional[str] = None,
postfix: Optional[str] = None,
compute_groups: Union[bool, List[List[str]]] = True,
**kwargs
): ...
def __call__(self, *args, **kwargs) -> Dict[str, Any]: ...
def update(self, *args, **kwargs) -> None: ...
def compute(self) -> Dict[str, Any]: ...
def reset(self) -> None: ...
def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": ...Basic metrics for accumulating and aggregating values across batches and distributed processes.
class MeanMetric(Metric):
def __init__(
self,
nan_strategy: str = "warn",
**kwargs
): ...
class SumMetric(Metric):
def __init__(
self,
nan_strategy: str = "warn",
**kwargs
): ...
class MaxMetric(Metric):
def __init__(
self,
nan_strategy: str = "warn",
**kwargs
): ...
class MinMetric(Metric):
def __init__(
self,
nan_strategy: str = "warn",
**kwargs
): ...
class CatMetric(Metric):
def __init__(
self,
nan_strategy: str = "warn",
**kwargs
): ...Metrics that maintain running statistics over streaming data without storing all values.
class RunningMean(Metric):
def __init__(
self,
window: int = 100,
**kwargs
): ...
class RunningSum(Metric):
def __init__(
self,
window: int = 100,
**kwargs
): ...Advanced wrappers that enhance metric functionality with additional capabilities.
class BootStrapper:
def __init__(
self,
base_metric: Union[Metric, Callable],
num_bootstraps: int = 100,
mean: bool = True,
std: bool = True,
raw: bool = False,
quantile: Optional[Union[float, Tensor]] = None,
sampling_strategy: str = "poisson",
**kwargs
): ...
class ClasswiseWrapper:
def __init__(
self,
metric: Metric,
labels: Optional[List[str]] = None,
**kwargs
): ...
class MetricTracker:
def __init__(
self,
metric: Metric,
maximize: bool = True,
**kwargs
): ...
class MinMaxMetric:
def __init__(
self,
base_metric: Metric,
**kwargs
): ...
class MultioutputWrapper:
def __init__(
self,
metric: Metric,
num_outputs: int,
**kwargs
): ...
class MultitaskWrapper:
def __init__(
self,
task_metrics: Dict[str, Metric],
**kwargs
): ...Specialized wrappers for complex metric computation scenarios.
class FeatureShare:
    """Wrapper for sharing feature-extraction work between metrics.

    NOTE(review): this stub shows a single ``metric`` argument — confirm
    against the installed torchmetrics version, which may accept a list of
    metrics instead.
    """

    def __init__(
        self,
        metric: Metric,
        reset_real_features: bool = True,
        **kwargs
    ): ...


class LambdaInputTransformer:
    """Apply ``transform_func`` to inputs before forwarding to the metric.

    When ``transform_labels`` is True the targets are transformed as well.
    """

    def __init__(
        self,
        metric: Metric,
        transform_func: Callable,
        transform_labels: bool = True,
        **kwargs
    ): ...


class MetricInputTransformer:
    """Base wrapper for metrics whose inputs are transformed before update."""

    def __init__(
        self,
        metric: Metric,
        **kwargs
    ): ...


class Running:
    """Compute the base metric over a sliding window of recent updates."""

    def __init__(
        self,
        base_metric: Metric,
        window_size: int = 100,
        **kwargs
    ): ...


class BinaryTargetTransformer:
    """Apply ``target_transform`` to targets before forwarding to the metric."""

    def __init__(
        self,
        metric: Metric,
        target_transform: Callable[[Tensor], Tensor],
        **kwargs
    ): ...

import torch
from torchmetrics import MetricCollection, Accuracy, F1Score, Precision, Recall

# Create a collection of classification metrics
metric_collection = MetricCollection({
    'accuracy': Accuracy(task="multiclass", num_classes=3),
    'f1': F1Score(task="multiclass", num_classes=3),
    'precision': Precision(task="multiclass", num_classes=3),
    'recall': Recall(task="multiclass", num_classes=3)
})

# Sample data
preds = torch.randn(32, 3).softmax(dim=-1)
target = torch.randint(0, 3, (32,))

# Compute all metrics at once
results = metric_collection(preds, target)
for metric_name, value in results.items():
    print(f"{metric_name}: {value:.4f}")

from torchmetrics import MeanMetric, SumMetric, MaxMetric
# Initialize aggregation metrics
mean_loss = MeanMetric()
total_samples = SumMetric()
max_confidence = MaxMetric()

# Accumulate values across batches
for batch_idx in range(10):
    batch_loss = torch.rand(1) * 2  # Random loss
    batch_size = torch.tensor(32.0)  # Batch size
    batch_max_conf = torch.rand(1)  # Max confidence in batch
    mean_loss.update(batch_loss)
    total_samples.update(batch_size)
    max_confidence.update(batch_max_conf)

# Get final aggregated values
print(f"Mean Loss: {mean_loss.compute():.4f}")
print(f"Total Samples: {total_samples.compute():.0f}")
print(f"Max Confidence: {max_confidence.compute():.4f}")

from torchmetrics.wrappers import BootStrapper
from torchmetrics import Accuracy

# Bootstrap accuracy for confidence intervals
base_accuracy = Accuracy(task="binary")
bootstrap_accuracy = BootStrapper(
    base_accuracy,
    num_bootstraps=1000,
    mean=True,
    std=True,
    quantile=torch.tensor([0.025, 0.975])  # 95% confidence interval
)

# Sample binary classification data
preds = torch.rand(100)
target = torch.randint(0, 2, (100,))

# Compute bootstrapped statistics
bootstrap_results = bootstrap_accuracy(preds, target)
print(f"Mean Accuracy: {bootstrap_results['mean']:.4f}")
print(f"Std Accuracy: {bootstrap_results['std']:.4f}")
print(f"95% Confidence Interval: [{bootstrap_results['quantile'][0]:.4f}, {bootstrap_results['quantile'][1]:.4f}]")

from torchmetrics.wrappers import ClasswiseWrapper
from torchmetrics import F1Score

# Compute F1 score per class
class_labels = ['cat', 'dog', 'bird']
base_f1 = F1Score(task="multiclass", num_classes=3, average=None)
classwise_f1 = ClasswiseWrapper(base_f1, labels=class_labels)

# Sample multiclass data
preds = torch.randn(100, 3).softmax(dim=-1)
target = torch.randint(0, 3, (100,))

# Get per-class results
classwise_results = classwise_f1(preds, target)
for class_name, f1_score in classwise_results.items():
    print(f"F1 Score for {class_name}: {f1_score:.4f}")

from torchmetrics.wrappers import MetricTracker
from torchmetrics import Accuracy

# Track best accuracy over time
tracker = MetricTracker(Accuracy(task="binary"), maximize=True)

# Simulate training epochs
accuracies = [0.6, 0.7, 0.65, 0.8, 0.75, 0.85, 0.82]
for epoch, acc in enumerate(accuracies):
    preds = torch.rand(100)
    target = torch.randint(0, 2, (100,))
    # Update tracker (automatically keeps best)
    result = tracker(preds, target)

# NOTE(review): MetricTracker typically requires tracker.increment() before
# each tracked step, and best_metric may be a method rather than an
# attribute — confirm against the installed torchmetrics version.
print(f"Best Accuracy: {tracker.best_metric:.4f}")
print(f"Best Accuracy at Epoch: {tracker.best_step}")

from torchmetrics import RunningMean
# Running mean with sliding window
running_mean = RunningMean(window=50)

# Simulate streaming data
for i in range(200):
    value = torch.tensor(float(i + torch.randn(1) * 0.1))
    running_mean.update(value)
    if i % 50 == 0:
        print(f"Step {i}: Running Mean = {running_mean.compute():.2f}")

from torchmetrics.wrappers import MultioutputWrapper
from torchmetrics import MeanSquaredError

# MSE for multi-output regression
multi_mse = MultioutputWrapper(MeanSquaredError(), num_outputs=3)

# Multi-output predictions and targets
preds = torch.randn(50, 3)  # 50 samples, 3 outputs
target = torch.randn(50, 3)

# Compute MSE for each output
results = multi_mse(preds, target)
for i, mse in enumerate(results):
    print(f"Output {i+1} MSE: {mse:.4f}")

from torchmetrics.wrappers import MultitaskWrapper
from torchmetrics import Accuracy, MeanSquaredError

# Metrics for multi-task learning
task_metrics = {
    'classification': Accuracy(task="multiclass", num_classes=5),
    'regression': MeanSquaredError()
}
multitask_metric = MultitaskWrapper(task_metrics)

# Sample multi-task predictions
task_preds = {
    'classification': torch.randn(32, 5).softmax(dim=-1),
    'regression': torch.randn(32, 1)
}
task_targets = {
    'classification': torch.randint(0, 5, (32,)),
    'regression': torch.randn(32, 1)
}

# Compute metrics for all tasks
task_results = multitask_metric(task_preds, task_targets)
for task, result in task_results.items():
    print(f"{task}: {result:.4f}")

from torchmetrics.wrappers import LambdaInputTransformer
from torchmetrics import Accuracy

# Transform inputs before metric computation
def logits_to_probs(logits):
    return torch.softmax(logits, dim=-1)

# Wrap accuracy with input transformation
transformed_accuracy = LambdaInputTransformer(
    Accuracy(task="multiclass", num_classes=3),
    transform_func=logits_to_probs,
    transform_labels=False  # Don't transform targets
)

# Raw logits input
logits = torch.randn(32, 3)
target = torch.randint(0, 3, (32,))

# Accuracy automatically applies softmax
acc = transformed_accuracy(logits, target)
print(f"Accuracy: {acc:.4f}")

from typing import Dict, List, Optional, Union, Callable, Any, Tuple
import torch
from torch import Tensor
MetricDict = Dict[str, Metric]
ComputeGroupsType = Union[bool, List[List[str]]]
NaNStrategy = Union["warn", "error", "ignore"]
SamplingStrategy = Union["poisson", "multinomial"]Install with Tessl CLI
npx tessl i tessl/pypi-torchmetrics