0
# Functional API
1
2
The functional API provides direct, stateless implementations of all metrics: each function returns its result immediately, with no object instantiation or state management. It offers 350+ functions across all domains.
3
4
## Overview
5
6
The functional API provides stateless versions of all TorchMetrics metrics. These functions compute metrics directly on input tensors without maintaining internal state, making them ideal for one-off computations and integration into custom training loops.
7
8
All functional implementations are available under `torchmetrics.functional` with domain-specific submodules mirroring the class-based organization.
9
10
## Import Patterns
11
12
```python
13
# General functional import
14
import torchmetrics.functional as F
15
16
# Domain-specific functional imports
17
import torchmetrics.functional.classification as FC
18
import torchmetrics.functional.regression as FR
19
import torchmetrics.functional.audio as FA
20
import torchmetrics.functional.image as FI
21
import torchmetrics.functional.text as FT
22
```
23
24
## Capabilities
25
26
### Classification Functions
27
28
Functional implementations of all classification metrics with support for binary, multiclass, and multilabel tasks.
29
30
```python { .api }
31
def accuracy(
32
preds: Tensor,
33
target: Tensor,
34
task: str,
35
threshold: float = 0.5,
36
num_classes: Optional[int] = None,
37
num_labels: Optional[int] = None,
38
average: Optional[str] = "micro",
39
multidim_average: str = "global",
40
top_k: Optional[int] = None,
41
ignore_index: Optional[int] = None,
42
validate_args: bool = True,
43
) -> Tensor: ...
44
45
def f1_score(
46
preds: Tensor,
47
target: Tensor,
48
task: str,
49
threshold: float = 0.5,
50
num_classes: Optional[int] = None,
51
num_labels: Optional[int] = None,
52
average: Optional[str] = "micro",
53
multidim_average: str = "global",
54
top_k: Optional[int] = None,
55
ignore_index: Optional[int] = None,
56
validate_args: bool = True,
57
) -> Tensor: ...
58
59
def auroc(
60
preds: Tensor,
61
target: Tensor,
62
task: str,
63
num_classes: Optional[int] = None,
64
num_labels: Optional[int] = None,
65
average: Optional[str] = "macro",
66
max_fpr: Optional[float] = None,
67
thresholds: Optional[Union[int, List[float], Tensor]] = None,
68
ignore_index: Optional[int] = None,
69
validate_args: bool = True,
70
) -> Tensor: ...
71
72
def precision(
73
preds: Tensor,
74
target: Tensor,
75
task: str,
76
threshold: float = 0.5,
77
num_classes: Optional[int] = None,
78
num_labels: Optional[int] = None,
79
average: Optional[str] = "micro",
80
multidim_average: str = "global",
81
top_k: Optional[int] = None,
82
ignore_index: Optional[int] = None,
83
validate_args: bool = True,
84
) -> Tensor: ...
85
86
def recall(
87
preds: Tensor,
88
target: Tensor,
89
task: str,
90
threshold: float = 0.5,
91
num_classes: Optional[int] = None,
92
num_labels: Optional[int] = None,
93
average: Optional[str] = "micro",
94
multidim_average: str = "global",
95
top_k: Optional[int] = None,
96
ignore_index: Optional[int] = None,
97
validate_args: bool = True,
98
) -> Tensor: ...
99
100
def confusion_matrix(
101
preds: Tensor,
102
target: Tensor,
103
task: str,
104
num_classes: Optional[int] = None,  # required for task="multiclass" only
105
threshold: float = 0.5,
106
num_labels: Optional[int] = None,
107
normalize: Optional[str] = None,
108
ignore_index: Optional[int] = None,
109
validate_args: bool = True,
110
) -> Tensor: ...
111
```
112
113
### Regression Functions
114
115
Functional implementations for regression metrics and correlation measures.
116
117
```python { .api }
118
def mean_squared_error(
119
preds: Tensor,
120
target: Tensor,
121
squared: bool = True,
122
num_outputs: int = 1,
123
) -> Tensor: ...
124
125
def mean_absolute_error(
126
preds: Tensor,
127
target: Tensor,
128
num_outputs: int = 1,
129
) -> Tensor: ...
130
131
def r2_score(
132
preds: Tensor,
133
target: Tensor,
134
num_outputs: int = 1,
135
multioutput: str = "uniform_average",
136
adjusted: int = 0,
137
) -> Tensor: ...
138
139
def pearson_corrcoef(
140
preds: Tensor,
141
target: Tensor,
142
num_outputs: int = 1,
143
) -> Tensor: ...
144
145
def spearman_corrcoef(
146
preds: Tensor,
147
target: Tensor,
148
num_outputs: int = 1,
149
) -> Tensor: ...
150
151
def cosine_similarity(
152
preds: Tensor,
153
target: Tensor,
154
reduction: str = "sum",
155
) -> Tensor: ...
156
```
157
158
### Audio Functions
159
160
Functional audio quality and separation metrics.
161
162
```python { .api }
163
def scale_invariant_signal_distortion_ratio(
164
preds: Tensor,
165
target: Tensor,
166
zero_mean: bool = False,
167
) -> Tensor: ...
168
169
def signal_distortion_ratio(
170
preds: Tensor,
171
target: Tensor,
172
use_cg_iter: Optional[int] = None,
173
filter_length: int = 512,
174
zero_mean: bool = False,
175
load_diag: Optional[float] = None,
176
) -> Tensor: ...
177
178
def permutation_invariant_training(
179
preds: Tensor,
180
target: Tensor,
181
metric_func: Callable,
182
mode: str = "speaker-wise",
183
eval_func: str = "max",
184
) -> Tuple[Tensor, Tensor]: ...  # (best metric per sample, best permutation per sample)
185
186
def perceptual_evaluation_speech_quality(
187
preds: Tensor,
188
target: Tensor,
189
fs: int,
190
mode: str,  # required: "wb" (wide-band) or "nb" (narrow-band)
191
) -> Tensor: ...
192
```
193
194
### Image Functions
195
196
Functional image quality assessment metrics.
197
198
```python { .api }
199
def peak_signal_noise_ratio(
200
preds: Tensor,
201
target: Tensor,
202
data_range: Optional[float] = None,
203
base: float = 10.0,
204
reduction: str = "elementwise_mean",
205
) -> Tensor: ...
206
207
def structural_similarity_index_measure(
208
preds: Tensor,
209
target: Tensor,
210
gaussian_kernel: bool = True,
211
sigma: Union[float, Tuple[float, float]] = 1.5,
212
kernel_size: Union[int, Tuple[int, int]] = 11,
213
reduction: str = "elementwise_mean",
214
data_range: Optional[float] = None,
215
k1: float = 0.01,
216
k2: float = 0.03,
217
) -> Tensor: ...
218
219
def multiscale_structural_similarity_index_measure(
220
preds: Tensor,
221
target: Tensor,
222
gaussian_kernel: bool = True,
223
sigma: Union[float, Tuple[float, float]] = 1.5,
224
kernel_size: Union[int, Tuple[int, int]] = 11,
225
reduction: str = "elementwise_mean",
226
data_range: Optional[float] = None,
227
k1: float = 0.01,
228
k2: float = 0.03,
229
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
230
normalize: Optional[str] = "relu",
231
) -> Tensor: ...
232
233
def universal_image_quality_index(
234
preds: Tensor,
235
target: Tensor,
236
kernel_size: Union[int, Tuple[int, int]] = 8,
237
sigma: Union[float, Tuple[float, float]] = 1.5,
238
reduction: str = "elementwise_mean",
239
) -> Tensor: ...
240
```
241
242
### Text Functions
243
244
Functional NLP metrics for text evaluation.
245
246
```python { .api }
247
def bleu_score(
248
preds: Sequence[str],
249
target: Sequence[Sequence[str]],
250
n_gram: int = 4,
251
smooth: bool = False,
252
weights: Optional[Sequence[float]] = None,
253
) -> Tensor: ...
254
255
def rouge_score(
256
preds: Union[str, Sequence[str]],
257
target: Union[str, Sequence[str], Sequence[Sequence[str]]],
258
rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
259
use_stemmer: bool = False,
260
normalizer: Optional[Callable[[str], str]] = None,
261
tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
262
accumulate: str = "best",
263
) -> Dict[str, Tensor]: ...
264
265
def word_error_rate(
266
preds: Union[str, List[str]],
267
target: Union[str, List[str]],
268
) -> Tensor: ...
269
270
def char_error_rate(
271
preds: Union[str, List[str]],
272
target: Union[str, List[str]],
273
) -> Tensor: ...
274
275
def edit_distance(
276
preds: Union[str, List[str]],
277
target: Union[str, List[str]],
278
substitution_cost: int = 1,
279
reduction: Optional[str] = "mean",
280
) -> Tensor: ...
281
```
282
283
### Clustering Functions
284
285
Functional clustering evaluation metrics.
286
287
```python { .api }
288
def adjusted_rand_score(
289
preds: Tensor,
290
target: Tensor,
291
) -> Tensor: ...
292
293
def normalized_mutual_info_score(
294
preds: Tensor,
295
target: Tensor,
296
average_method: str = "arithmetic",
297
) -> Tensor: ...
298
299
def calinski_harabasz_score(
300
data: Tensor,
301
labels: Tensor,
302
) -> Tensor: ...
303
304
def davies_bouldin_score(
305
data: Tensor,
306
labels: Tensor,
307
) -> Tensor: ...
308
```
309
310
### Pairwise Functions
311
312
Functional pairwise distance and similarity measures.
313
314
```python { .api }
315
def pairwise_cosine_similarity(
316
x: Tensor,
317
y: Optional[Tensor] = None,
318
reduction: Optional[str] = None,
319
zero_diagonal: Optional[bool] = None,  # None resolves to True when y is omitted, False otherwise
320
) -> Tensor: ...
321
322
def pairwise_euclidean_distance(
323
x: Tensor,
324
y: Optional[Tensor] = None,
325
reduction: Optional[str] = None,
326
zero_diagonal: Optional[bool] = None,  # None resolves to True when y is omitted, False otherwise
327
) -> Tensor: ...
328
329
def pairwise_manhattan_distance(
330
x: Tensor,
331
y: Optional[Tensor] = None,
332
reduction: Optional[str] = None,
333
zero_diagonal: Optional[bool] = None,  # None resolves to True when y is omitted, False otherwise
334
) -> Tensor: ...
335
336
def pairwise_minkowski_distance(
337
x: Tensor,
338
y: Optional[Tensor] = None,
339
exponent: float = 2.0,
340
reduction: Optional[str] = None,
341
zero_diagonal: Optional[bool] = None,  # None resolves to True when y is omitted, False otherwise
342
) -> Tensor: ...
343
```
344
345
## Usage Examples
346
347
### Basic Functional Usage
348
349
```python
350
import torch
351
import torchmetrics.functional as F
352
353
# Binary classification
354
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
355
target = torch.tensor([0, 1, 1, 0])
356
357
# Compute metrics directly
358
acc = F.accuracy(preds, target, task="binary")
359
f1 = F.f1_score(preds, target, task="binary")
360
auc = F.auroc(preds, target, task="binary")
361
362
print(f"Accuracy: {acc:.4f}")
363
print(f"F1 Score: {f1:.4f}")
364
print(f"AUROC: {auc:.4f}")
365
```
366
367
### Multiclass Classification
368
369
```python
370
import torchmetrics.functional.classification as FC
371
372
# Multiclass predictions
373
preds = torch.randn(100, 5).softmax(dim=-1)
374
target = torch.randint(0, 5, (100,))
375
376
# Compute various metrics
377
acc = FC.multiclass_accuracy(preds, target, num_classes=5)
378
precision = FC.multiclass_precision(preds, target, num_classes=5, average="macro")
379
recall = FC.multiclass_recall(preds, target, num_classes=5, average="macro")
380
conf_matrix = FC.multiclass_confusion_matrix(preds, target, num_classes=5)
381
382
print(f"Accuracy: {acc:.4f}")
383
print(f"Macro Precision: {precision:.4f}")
384
print(f"Macro Recall: {recall:.4f}")
385
print(f"Confusion Matrix Shape: {conf_matrix.shape}")
386
```
387
388
### Regression Metrics
389
390
```python
391
import torchmetrics.functional.regression as FR
392
393
# Regression predictions
394
preds = torch.randn(50, 1)
395
target = torch.randn(50, 1)
396
397
# Compute regression metrics
398
mse = FR.mean_squared_error(preds, target)
399
mae = FR.mean_absolute_error(preds, target)
400
r2 = FR.r2_score(preds, target)
401
pearson = FR.pearson_corrcoef(preds, target)
402
403
print(f"MSE: {mse:.4f}")
404
print(f"MAE: {mae:.4f}")
405
print(f"R²: {r2:.4f}")
406
print(f"Pearson Correlation: {pearson:.4f}")
407
```
408
409
### Image Quality Assessment
410
411
```python
412
import torchmetrics.functional.image as FI
413
414
# Image tensors
415
preds = torch.rand(4, 3, 256, 256)
416
target = torch.rand(4, 3, 256, 256)
417
418
# Compute image quality metrics
419
psnr = FI.peak_signal_noise_ratio(preds, target)
420
ssim = FI.structural_similarity_index_measure(preds, target)
421
ms_ssim = FI.multiscale_structural_similarity_index_measure(preds, target)
422
423
print(f"PSNR: {psnr:.4f}")
424
print(f"SSIM: {ssim:.4f}")
425
print(f"MS-SSIM: {ms_ssim:.4f}")
426
```
427
428
### Text Evaluation
429
430
```python
431
import torchmetrics.functional.text as FT
432
433
# Text evaluation
434
preds = ["the cat is on the mat"]
435
target = [["there is a cat on the mat", "a cat is on the mat"]]
436
437
# Compute text metrics
438
bleu = FT.bleu_score(preds, target)
439
rouge_scores = FT.rouge_score(preds[0], target[0])
440
441
print(f"BLEU Score: {bleu:.4f}")
442
print(f"ROUGE-1 F1: {rouge_scores['rouge1_fmeasure']:.4f}")
443
print(f"ROUGE-L F1: {rouge_scores['rougeL_fmeasure']:.4f}")
444
445
# Error rates
446
pred_text = ["this is a test"]
447
target_text = ["this is the test"]
448
wer = FT.word_error_rate(pred_text, target_text)
449
cer = FT.char_error_rate(pred_text, target_text)
450
451
print(f"Word Error Rate: {wer:.4f}")
452
print(f"Character Error Rate: {cer:.4f}")
453
```
454
455
### Audio Quality
456
457
```python
458
import torchmetrics.functional.audio as FA
459
460
# Audio signals
461
preds = torch.randn(4, 8000) # 4 samples, 8000 time steps
462
target = torch.randn(4, 8000)
463
464
# Compute audio metrics
465
si_sdr = FA.scale_invariant_signal_distortion_ratio(preds, target)
466
si_snr = FA.scale_invariant_signal_noise_ratio(preds, target)
467
468
print(f"SI-SDR: {si_sdr.mean():.4f} dB")  # functional result is per-sample (shape (4,)); average before formatting
469
print(f"SI-SNR: {si_snr.mean():.4f} dB")  # functional result is per-sample (shape (4,)); average before formatting
470
```
471
472
### Pairwise Distances
473
474
```python
475
import torchmetrics.functional.pairwise as FP
476
477
# Feature vectors
478
x = torch.randn(100, 64) # 100 samples, 64-dim features
479
y = torch.randn(50, 64) # 50 samples, 64-dim features
480
481
# Compute pairwise similarities and distances
482
cosine_sim = FP.pairwise_cosine_similarity(x, y)
483
euclidean_dist = FP.pairwise_euclidean_distance(x, y)
484
manhattan_dist = FP.pairwise_manhattan_distance(x, y)
485
486
print(f"Cosine Similarity Shape: {cosine_sim.shape}") # (100, 50)
487
print(f"Euclidean Distance Shape: {euclidean_dist.shape}") # (100, 50)
488
print(f"Manhattan Distance Shape: {manhattan_dist.shape}") # (100, 50)
489
```
490
491
### Clustering Evaluation
492
493
```python
494
import torchmetrics.functional.clustering as FCL
495
496
# Clustering results
497
pred_clusters = torch.randint(0, 3, (100,))
498
true_clusters = torch.randint(0, 3, (100,))
499
500
# Clustering metrics
501
ari = FCL.adjusted_rand_score(pred_clusters, true_clusters)
502
nmi = FCL.normalized_mutual_info_score(pred_clusters, true_clusters)
503
504
print(f"Adjusted Rand Index: {ari:.4f}")
505
print(f"Normalized Mutual Info: {nmi:.4f}")
506
507
# Internal clustering metrics (require data)
508
data = torch.randn(100, 10) # 100 samples, 10 features
509
ch_score = FCL.calinski_harabasz_score(data, pred_clusters)
510
db_score = FCL.davies_bouldin_score(data, pred_clusters)
511
512
print(f"Calinski-Harabasz Score: {ch_score:.4f}")
513
print(f"Davies-Bouldin Score: {db_score:.4f}")
514
```
515
516
## Functional vs Class-based API
517
518
### When to Use Functional API
519
520
- One-off metric computations
521
- Custom training loops without Lightning
522
- Minimal memory overhead requirements
523
- Integration with existing codebases
524
- Research experiments requiring flexibility
525
526
### When to Use Class-based API
527
528
- Accumulating metrics across batches
529
- Distributed training scenarios
530
- PyTorch Lightning integration
531
- Automatic state management needed
532
- Complex metric tracking workflows
533
534
## Types
535
536
```python { .api }
537
from typing import Union, Optional, List, Dict, Tuple, Sequence, Callable, Any, Literal
538
import torch
539
from torch import Tensor
540
541
# Common functional types
542
FunctionalOutput = Union[Tensor, Dict[str, Tensor], Tuple[Tensor, ...]]
543
TaskType = Literal["binary", "multiclass", "multilabel"]
544
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
545
ReductionType = Literal["mean", "sum", "none", "elementwise_mean"]
546
```