# Classification Metrics

Comprehensive classification metrics supporting binary, multiclass, and multilabel scenarios. All classification metrics support automatic task detection and provide consistent APIs across different classification types with variants for each task type.

## Capabilities

### Accuracy Metrics

Measures the proportion of correct predictions among all predictions made.

```python { .api }
class Accuracy(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryAccuracy(Metric):
    def __init__(
        self,
        threshold: float = 0.5,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassAccuracy(Metric):
    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "micro",
        top_k: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelAccuracy(Metric):
    def __init__(
        self,
        num_labels: int,
        threshold: float = 0.5,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Area Under ROC Curve (AUROC)

Computes Area Under the Receiver Operating Characteristic Curve, measuring the model's ability to distinguish between classes.

```python { .api }
class AUROC(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryAUROC(Metric):
    def __init__(
        self,
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassAUROC(Metric):
    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelAUROC(Metric):
    def __init__(
        self,
        num_labels: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### ROC Curves

Computes Receiver Operating Characteristic curves for visualization and analysis.

```python { .api }
class ROC(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryROC(Metric):
    def __init__(
        self,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassROC(Metric):
    def __init__(
        self,
        num_classes: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelROC(Metric):
    def __init__(
        self,
        num_labels: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Precision and Recall

Measures the proportion of relevant instances among retrieved instances (precision) and retrieved instances among relevant instances (recall).

```python { .api }
class Precision(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class Recall(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

Each precision and recall metric also has Binary, Multiclass, and Multilabel variants with task-specific parameters.

### F-Scores

Harmonic mean of precision and recall, with F1 being the most commonly used (beta=1).

```python { .api }
class F1Score(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class FBetaScore(Metric):
    def __init__(
        self,
        task: str,
        beta: float = 1.0,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Average Precision

Computes average precision score, which summarizes a precision-recall curve as the weighted mean of precisions.

```python { .api }
class AveragePrecision(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Confusion Matrix

Computes confusion matrix for evaluating classification accuracy with detailed breakdown of true/false positives and negatives.

```python { .api }
class ConfusionMatrix(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        threshold: float = 0.5,
        num_labels: Optional[int] = None,
        normalize: Optional[str] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Statistical Scores

Computes true positives, false positives, true negatives, false negatives, and support statistics.

```python { .api }
class StatScores(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Threshold-Based Metrics

Metrics that find optimal thresholds or evaluate performance at specific operating points.

```python { .api }
class PrecisionAtFixedRecall(Metric):
    def __init__(
        self,
        task: str,
        min_recall: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class RecallAtFixedPrecision(Metric):
    def __init__(
        self,
        task: str,
        min_precision: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class SensitivityAtSpecificity(Metric):
    def __init__(
        self,
        task: str,
        min_specificity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class SpecificityAtSensitivity(Metric):
    def __init__(
        self,
        task: str,
        min_sensitivity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Advanced Classification Metrics

Specialized metrics for specific classification scenarios.

```python { .api }
class CohenKappa(Metric):
    def __init__(
        self,
        task: str,
        num_classes: int,
        threshold: float = 0.5,
        num_labels: Optional[int] = None,
        weights: Optional[str] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MatthewsCorrCoef(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class JaccardIndex(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class HammingDistance(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class ExactMatch(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Calibration and Ranking Metrics

Metrics for evaluating model calibration and ranking quality.

```python { .api }
class CalibrationError(Metric):
    def __init__(
        self,
        task: str,
        n_bins: int = 15,
        norm: str = "l1",
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelRankingAveragePrecision(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelRankingLoss(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelCoverageError(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...
```

## Usage Examples

### Basic Classification

```python
import torch
from torchmetrics import Accuracy, F1Score, ConfusionMatrix

# Binary classification
binary_acc = Accuracy(task="binary")
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 0])
print(binary_acc(preds, target))

# Multiclass classification
multiclass_f1 = F1Score(task="multiclass", num_classes=3, average="macro")
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))
print(multiclass_f1(preds, target))

# Multilabel classification
multilabel_cm = ConfusionMatrix(task="multilabel", num_labels=3)
preds = torch.randn(10, 3).sigmoid()
target = torch.randint(0, 2, (10, 3))
print(multilabel_cm(preds, target))
```

### Threshold-based Metrics

```python
from torchmetrics import PrecisionAtFixedRecall, ROC

# Find precision at 90% recall
precision_at_recall = PrecisionAtFixedRecall(task="binary", min_recall=0.9)
preds = torch.randn(100).sigmoid()
target = torch.randint(0, 2, (100,))
precision_value, threshold = precision_at_recall(preds, target)
print(f"Precision: {precision_value:.3f} at threshold: {threshold:.3f}")

# Compute ROC curve
roc = ROC(task="binary")
fpr, tpr, thresholds = roc(preds, target)
```

## Types

```python { .api }
TaskType = Literal["binary", "multiclass", "multilabel"]
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
MDMCAverageType = Literal["global", "samplewise"]
ThresholdType = Union[float, List[float], Tensor]
```