# TorchMetrics

A metrics library for PyTorch and PyTorch Lightning that provides 100+ tested metric implementations across classification, regression, audio, image, text, and other machine learning domains. TorchMetrics supports distributed metric computation behind a consistent API, handles device placement automatically, and integrates directly with PyTorch workflows.

## Package Information

- **Package Name**: torchmetrics
- **Package Type**: Library
- **Language**: Python
- **Installation**: `pip install torchmetrics`

## Core Imports

```python
import torchmetrics
```

For functional API:
```python
import torchmetrics.functional as F
```

For specific metrics:
```python
from torchmetrics import Accuracy, AUROC, F1Score
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
from torchmetrics.regression import MeanSquaredError, R2Score
```

## Basic Usage

```python
import torch
import torchmetrics

# Initialize metrics
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
f1 = torchmetrics.F1Score(task="multiclass", num_classes=3)

# Create sample predictions and targets
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))

# Compute metrics
acc_score = accuracy(preds, target)
f1_score = f1(preds, target)

print(f"Accuracy: {acc_score:.4f}")
print(f"F1 Score: {f1_score:.4f}")

# Using functional API (aliased so the imports do not shadow the metric objects above)
from torchmetrics.functional import accuracy as accuracy_func, f1_score as f1_func

acc_functional = accuracy_func(preds, target, task="multiclass", num_classes=3)
f1_functional = f1_func(preds, target, task="multiclass", num_classes=3)

print(f"Functional Accuracy: {acc_functional:.4f}")
print(f"Functional F1: {f1_functional:.4f}")
```

## Architecture

TorchMetrics is built around two interfaces, modular and functional, with containers and wrappers layered on top:

- **Modular (Class-based) Interface**: Stateful metric classes that accumulate values across batches and provide automatic synchronization in distributed settings
- **Functional Interface**: Stateless functions for single-batch computations and one-off metric calculations
- **MetricCollection**: Container for organizing and computing multiple metrics simultaneously
- **Wrappers**: Advanced functionality for bootstrapping, per-class computation, and multi-task scenarios

All metrics inherit from the base `Metric` class, which gives every metric the same behavior for device handling, state management, and distributed computation.

## Capabilities

### Classification Metrics

Classification metrics for binary, multiclass, and multilabel tasks, including accuracy, precision, recall, F-scores, ROC/AUC, confusion matrices, and threshold-based metrics.

```python { .api }
class Accuracy(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class AUROC(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class F1Score(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class Precision(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class Recall(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class ConfusionMatrix(Metric):
    def __init__(self, task: str, num_classes: int, **kwargs): ...
```

[Classification](./classification.md)
98
99
### Regression Metrics
100
101
Metrics for regression tasks including error measurements, correlation coefficients, and explained variance measures for continuous target prediction evaluation.
102
103
```python { .api }
104
class MeanSquaredError(Metric):
105
def __init__(self, **kwargs): ...
106
107
class MeanAbsoluteError(Metric):
108
def __init__(self, **kwargs): ...
109
110
class R2Score(Metric):
111
def __init__(self, num_outputs: int = 1, **kwargs): ...
112
113
class PearsonCorrCoef(Metric):
114
def __init__(self, num_outputs: int = 1, **kwargs): ...
115
```
116
117
[Regression](./regression.md)

### Audio Metrics

Specialized metrics for audio processing and speech evaluation including signal-to-noise ratios, perceptual quality measures, and separation metrics.

```python { .api }
class ScaleInvariantSignalDistortionRatio(Metric):
    def __init__(self, **kwargs): ...

class PermutationInvariantTraining(Metric):
    def __init__(self, metric_func, mode: str = "speaker-wise", **kwargs): ...

class PerceptualEvaluationSpeechQuality(Metric):
    def __init__(self, fs: int, mode: str = "wb", **kwargs): ...
```

[Audio](./audio.md)

### Image Quality Metrics

Image quality assessment metrics including structural similarity, peak signal-to-noise ratio, and perceptual quality measures for computer vision applications.

```python { .api }
class StructuralSimilarityIndexMeasure(Metric):
    def __init__(self, **kwargs): ...

class PeakSignalNoiseRatio(Metric):
    def __init__(self, **kwargs): ...

class FrechetInceptionDistance(Metric):
    def __init__(self, feature: int = 2048, **kwargs): ...
```

[Image](./image.md)

### Text Metrics

Natural language processing metrics for translation, summarization, and text generation evaluation including BLEU, ROUGE, and semantic similarity measures.

```python { .api }
class BLEUScore(Metric):
    def __init__(self, n_gram: int = 4, **kwargs): ...

class ROUGEScore(Metric):
    def __init__(self, rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL"), **kwargs): ...

class BERTScore(Metric):
    def __init__(self, model_name_or_path: str = "distilbert-base-uncased", **kwargs): ...
```

[Text](./text.md)

### Detection Metrics

Object detection and instance segmentation metrics for evaluating bounding box predictions, IoU calculations, and mean average precision.

```python { .api }
class MeanAveragePrecision(Metric):
    def __init__(self, **kwargs): ...

class IntersectionOverUnion(Metric):
    def __init__(self, **kwargs): ...

class PanopticQuality(Metric):
    def __init__(self, **kwargs): ...
```

[Detection](./detection.md)

### Clustering Metrics

Unsupervised learning evaluation metrics including mutual information, Rand indices, and intrinsic measures such as the Calinski-Harabasz score for cluster quality assessment.

```python { .api }
class AdjustedRandScore(Metric):
    def __init__(self, **kwargs): ...

class NormalizedMutualInfoScore(Metric):
    def __init__(self, average_method: str = "arithmetic", **kwargs): ...

class CalinskiHarabaszScore(Metric):
    def __init__(self, **kwargs): ...
```

[Clustering](./clustering.md)

### Information Retrieval Metrics

Metrics for ranking and retrieval systems including precision at k, mean average precision, and normalized discounted cumulative gain.

```python { .api }
class RetrievalMAP(Metric):
    def __init__(self, **kwargs): ...

class RetrievalNormalizedDCG(Metric):
    def __init__(self, top_k: Optional[int] = None, **kwargs): ...

class RetrievalMRR(Metric):
    def __init__(self, **kwargs): ...
```

[Retrieval](./retrieval.md)

### Segmentation Metrics

Semantic and instance segmentation evaluation including Dice coefficients, Intersection over Union, and Hausdorff distance for pixel-level predictions.

```python { .api }
class DiceScore(Metric):
    def __init__(self, **kwargs): ...

class MeanIoU(Metric):
    def __init__(self, num_classes: int, **kwargs): ...

class HausdorffDistance(Metric):
    def __init__(self, **kwargs): ...
```

[Segmentation](./segmentation.md)

### Multimodal Metrics

Metrics for evaluating multimodal AI systems including video-audio synchronization and cross-modal quality assessment.

```python { .api }
class LipVertexError(Metric):
    def __init__(self, **kwargs): ...

class CLIPScore(Metric):
    def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14", **kwargs): ...
```

[Multimodal](./multimodal.md)

### Nominal/Categorical Metrics

Statistical measures for analyzing associations and agreements between categorical variables.

```python { .api }
class CramersV(Metric):
    def __init__(self, num_classes: int, **kwargs): ...

class FleissKappa(Metric):
    def __init__(self, mode: str = "counts", **kwargs): ...
```

[Nominal](./nominal.md)

### Shape Metrics

Metrics for analyzing geometric shapes and spatial configurations.

```python { .api }
class ProcrustesDisparity(Metric):
    def __init__(self, **kwargs): ...
```

[Shape](./shape.md)

### Video Metrics

Specialized metrics for video quality assessment and evaluation.

```python { .api }
class VideoMultiMethodAssessmentFusion(Metric):
    def __init__(self, **kwargs): ...
```

[Video](./video.md)

### Aggregation and Utilities

Core utilities for metric aggregation, distributed computation, and advanced metric composition including running statistics and bootstrapping.

```python { .api }
class MetricCollection:
    def __init__(self, metrics: Union[Dict[str, Metric], List[Metric]], **kwargs): ...

class MeanMetric(Metric):
    def __init__(self, **kwargs): ...

class SumMetric(Metric):
    def __init__(self, **kwargs): ...

class BootStrapper:
    def __init__(self, base_metric: Metric, num_bootstraps: int = 10, **kwargs): ...
```

[Utilities](./utilities.md)
307
308
### Functional API
309
310
Direct functional implementations of all metrics for stateless computation, providing immediate results without object instantiation or state management.
311
312
```python { .api }
313
def accuracy(preds: Tensor, target: Tensor, task: str, **kwargs) -> Tensor: ...
314
def f1_score(preds: Tensor, target: Tensor, task: str, **kwargs) -> Tensor: ...
315
def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor: ...
316
def structural_similarity_index_measure(preds: Tensor, target: Tensor, **kwargs) -> Tensor: ...
317
```
318
319
[Functional](./functional.md)

## Types

Imports used by the type signatures in this document:

```python
from typing import Union, Optional, Tuple, Dict, List, Any, Callable, Literal
import torch
from torch import Tensor
```

Common type aliases:

```python { .api }
TaskType = Literal["binary", "multiclass", "multilabel"]
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
MDMCAverageType = Literal["global", "samplewise"]
ThresholdType = Union[float, List[float], Tensor]
```

Base metric class:

```python { .api }
class Metric(torch.nn.Module):
    """Base class for all metrics."""
    def __init__(self, **kwargs): ...
    def __call__(self, *args, **kwargs) -> Any: ...
    def update(self, *args, **kwargs) -> None: ...
    def compute(self) -> Any: ...
    def reset(self) -> None: ...
    def to(self, device: Union[str, torch.device]) -> "Metric": ...
    def forward(self, *args, **kwargs) -> Any: ...
```