A PyTorch-native metrics library providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other ML domains.
Comprehensive classification metrics supporting binary, multiclass, and multilabel scenarios. All classification metrics support automatic task detection and provide consistent APIs across different classification types with variants for each task type.
Measures the proportion of correct predictions among all predictions made.
class Accuracy(Metric):
    """Proportion of correct predictions among all predictions made.

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant;
    the remaining arguments are forwarded to that variant as applicable.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class BinaryAccuracy(Metric):
    """Binary-task variant of Accuracy.

    Probabilistic predictions are binarized at ``threshold``.
    """

    def __init__(
        self,
        threshold: float = 0.5,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MulticlassAccuracy(Metric):
    """Multiclass-task variant of Accuracy; requires ``num_classes``."""

    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "micro",
        top_k: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MultilabelAccuracy(Metric):
    """Multilabel-task variant of Accuracy; requires ``num_labels``.

    Per-label probabilities are binarized at ``threshold``.
    """

    def __init__(
        self,
        num_labels: int,
        threshold: float = 0.5,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Computes Area Under the Receiver Operating Characteristic Curve, measuring the model's ability to distinguish between classes.
class AUROC(Metric):
    """Area Under the ROC Curve.

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    ``max_fpr`` applies to the binary case only (partial AUROC).
    """

    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class BinaryAUROC(Metric):
    """Binary-task variant of AUROC.

    ``max_fpr`` restricts the integration range (partial AUROC);
    ``thresholds`` trades exactness for constant memory when set.
    """

    def __init__(
        self,
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MulticlassAUROC(Metric):
    """Multiclass-task variant of AUROC; requires ``num_classes``."""

    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MultilabelAUROC(Metric):
    """Multilabel-task variant of AUROC; requires ``num_labels``."""

    def __init__(
        self,
        num_labels: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Computes Receiver Operating Characteristic curves for visualization and analysis.
class ROC(Metric):
    """Receiver Operating Characteristic curve.

    Task wrapper returning curve data (see the binary usage example:
    it yields fpr, tpr, and thresholds) rather than a scalar score.
    """

    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class BinaryROC(Metric):
    """Binary-task variant of ROC."""

    def __init__(
        self,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MulticlassROC(Metric):
    """Multiclass-task variant of ROC; requires ``num_classes``."""

    def __init__(
        self,
        num_classes: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MultilabelROC(Metric):
    """Multilabel-task variant of ROC; requires ``num_labels``."""

    def __init__(
        self,
        num_labels: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Measures the proportion of relevant instances among retrieved instances (precision) and retrieved instances among relevant instances (recall).
class Precision(Metric):
    """Proportion of relevant instances among retrieved instances.

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class Recall(Metric):
    """Proportion of retrieved instances among relevant instances.

    Task wrapper with the same signature shape as Precision.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Each precision and recall metric also has Binary, Multiclass, and Multilabel variants with task-specific parameters.
Harmonic mean of precision and recall, with F1 being the most commonly used (beta=1).
class F1Score(Metric):
    """Harmonic mean of precision and recall (FBetaScore with beta=1).

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class FBetaScore(Metric):
    """Weighted harmonic mean of precision and recall.

    ``beta`` controls the recall weighting; beta=1 (the default)
    is equivalent to F1Score.
    """

    def __init__(
        self,
        task: str,
        beta: float = 1.0,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Computes average precision score, which summarizes a precision-recall curve as the weighted mean of precisions.
class AveragePrecision(Metric):
    """Summarizes a precision-recall curve as the weighted mean of precisions.

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    """

    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Computes confusion matrix for evaluating classification accuracy with detailed breakdown of true/false positives and negatives.
class ConfusionMatrix(Metric):
def __init__(
self,
task: str,
num_classes: int,
threshold: float = 0.5,
num_labels: Optional[int] = None,
normalize: Optional[str] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs
): ...Computes true positives, false positives, true negatives, false negatives, and support statistics.
class StatScores(Metric):
    """True/false positives, true/false negatives, and support statistics.

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Metrics that find optimal thresholds or evaluate performance at specific operating points.
class PrecisionAtFixedRecall(Metric):
    """Precision at the operating point where recall is at least ``min_recall``.

    Returns both the precision value and the corresponding threshold
    (see the usage example below).
    """

    def __init__(
        self,
        task: str,
        min_recall: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class RecallAtFixedPrecision(Metric):
    """Recall at the operating point where precision is at least ``min_precision``."""

    def __init__(
        self,
        task: str,
        min_precision: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class SensitivityAtSpecificity(Metric):
    """Sensitivity at the operating point where specificity is at least ``min_specificity``."""

    def __init__(
        self,
        task: str,
        min_specificity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class SpecificityAtSensitivity(Metric):
    """Specificity at the operating point where sensitivity is at least ``min_sensitivity``."""

    def __init__(
        self,
        task: str,
        min_sensitivity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Specialized metrics for specific classification scenarios.
class CohenKappa(Metric):
    """Cohen's kappa: inter-rater agreement corrected for chance.

    Task wrapper: ``task`` selects the task-specific variant; ``weights``
    selects the (optional) disagreement weighting scheme.

    Fix: ``num_classes`` was declared as a required ``int``, inconsistent
    with every other task wrapper in this API (all declare
    ``num_classes: Optional[int] = None``) and unusable for
    ``task="binary"``. It is now optional; positional callers that passed
    it are unaffected.
    """

    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        threshold: float = 0.5,
        num_labels: Optional[int] = None,
        weights: Optional[str] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MatthewsCorrCoef(Metric):
    """Matthews correlation coefficient.

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class JaccardIndex(Metric):
    """Jaccard index (intersection over union).

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class HammingDistance(Metric):
    """Hamming distance between predictions and targets.

    Task wrapper: ``task`` selects the binary/multiclass/multilabel variant.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class ExactMatch(Metric):
    """Exact-match score across all labels/classes of a sample.

    Task wrapper: ``task`` selects the task-specific variant.
    """

    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

Metrics for evaluating model calibration and ranking quality.
class CalibrationError(Metric):
    """Calibration error of predicted probabilities.

    ``n_bins`` controls the confidence binning; ``norm`` selects the
    aggregation norm (default "l1").
    """

    def __init__(
        self,
        task: str,
        n_bins: int = 15,
        norm: str = "l1",
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
class MultilabelRankingAveragePrecision(Metric):
    """Label-ranking average precision for multilabel data; requires ``num_labels``."""

    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...
class MultilabelRankingLoss(Metric):
    """Label-ranking loss for multilabel data; requires ``num_labels``."""

    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...
class MultilabelCoverageError(Metric):
    """Coverage error for multilabel ranking; requires ``num_labels``."""

    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...

import torch
from torchmetrics import Accuracy, F1Score, ConfusionMatrix

# Binary classification: probabilistic preds, integer {0, 1} targets.
# Probabilities are binarized at the default threshold of 0.5.
binary_acc = Accuracy(task="binary")
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 0])
print(binary_acc(preds, target))

# Multiclass classification: per-class probabilities (softmax over the
# last dim) against integer class-index targets; macro-averaged F1.
multiclass_f1 = F1Score(task="multiclass", num_classes=3, average="macro")
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))
print(multiclass_f1(preds, target))

# Multilabel classification: per-label sigmoid probabilities against
# independent binary targets of shape (batch, num_labels).
multilabel_cm = ConfusionMatrix(task="multilabel", num_labels=3)
preds = torch.randn(10, 3).sigmoid()
target = torch.randint(0, 2, (10, 3))
print(multilabel_cm(preds, target))

from torchmetrics import PrecisionAtFixedRecall, ROC
# Find precision at 90% recall
precision_at_recall = PrecisionAtFixedRecall(task="binary", min_recall=0.9)
preds = torch.randn(100).sigmoid()
target = torch.randint(0, 2, (100,))
precision_value, threshold = precision_at_recall(preds, target)
print(f"Precision: {precision_value:.3f} at threshold: {threshold:.3f}")
# Compute ROC curve
roc = ROC(task="binary")
fpr, tpr, thresholds = roc(preds, target)TaskType = Union["binary", "multiclass", "multilabel"]
AverageType = Union["micro", "macro", "weighted", "none", None]
MDMCAverageType = Union["global", "samplewise"]
ThresholdType = Union[float, List[float], Tensor]Install with Tessl CLI
npx tessl i tessl/pypi-torchmetrics