# Functional API

The functional API provides direct, stateless implementations of all metrics: each call returns its result immediately, with no object instantiation or state management. It offers 350+ functions across all domains.

## Overview

The functional API provides stateless versions of all TorchMetrics metrics. These functions compute metrics directly on input tensors without maintaining internal state, making them ideal for one-off computations and integration into custom training loops.

All functional implementations are available under `torchmetrics.functional` with domain-specific submodules mirroring the class-based organization.

## Import Patterns

```python
# General functional import
import torchmetrics.functional as F

# Domain-specific functional imports
import torchmetrics.functional.classification as FC
import torchmetrics.functional.regression as FR
import torchmetrics.functional.audio as FA
import torchmetrics.functional.image as FI
import torchmetrics.functional.text as FT
```

## Capabilities

### Classification Functions

Functional implementations of all classification metrics with support for binary, multiclass, and multilabel tasks.

```python { .api }
def accuracy(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def f1_score(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def auroc(
    preds: Tensor,
    target: Tensor,
    task: str,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "macro",
    max_fpr: Optional[float] = None,
    thresholds: Optional[Union[int, List[float], Tensor]] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def precision(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def recall(
    preds: Tensor,
    target: Tensor,
    task: str,
    threshold: float = 0.5,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
    average: Optional[str] = "micro",
    multidim_average: str = "global",
    top_k: Optional[int] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...

def confusion_matrix(
    preds: Tensor,
    target: Tensor,
    task: str,
    num_classes: int,
    threshold: float = 0.5,
    num_labels: Optional[int] = None,
    normalize: Optional[str] = None,
    ignore_index: Optional[int] = None,
    validate_args: bool = True,
) -> Tensor: ...
```
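
The `average` parameter in the signatures above changes what a single number means. The following dependency-free sketch illustrates the usual definitions of micro and macro averaging for accuracy. The helpers `micro_accuracy` and `macro_accuracy` are hypothetical names written for this illustration, not TorchMetrics functions, and skipping classes that never appear in `target` is a simplifying assumption.

```python
from collections import Counter
from typing import Sequence


def micro_accuracy(preds: Sequence[int], target: Sequence[int]) -> float:
    """Micro: pool every sample, then divide correct by total."""
    correct = sum(p == t for p, t in zip(preds, target))
    return correct / len(target)


def macro_accuracy(preds: Sequence[int], target: Sequence[int], num_classes: int) -> float:
    """Macro: compute accuracy per class, then take the unweighted mean."""
    support = Counter(target)  # samples per true class
    hits = Counter(t for p, t in zip(preds, target) if p == t)  # correct per class
    # Assumption: classes with no samples in `target` are left out of the mean.
    per_class = [hits[c] / support[c] for c in range(num_classes) if support[c] > 0]
    return sum(per_class) / len(per_class)


# Imbalanced toy data: class 0 dominates and class 1 is always missed.
target = [0, 0, 0, 0, 1, 2]
preds = [0, 0, 0, 0, 0, 2]

micro = micro_accuracy(preds, target)                  # 5 of 6 samples correct
macro = macro_accuracy(preds, target, num_classes=3)   # mean of 1.0, 0.0, 1.0

print(f"micro: {micro:.4f}, macro: {macro:.4f}")
```

On imbalanced data the two can differ substantially, so it is worth passing `average` explicitly rather than relying on a default.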

### Regression Functions

Functional implementations for regression metrics and correlation measures.

```python { .api }
def mean_squared_error(
    preds: Tensor,
    target: Tensor,
    squared: bool = True,
    num_outputs: int = 1,
) -> Tensor: ...

def mean_absolute_error(
    preds: Tensor,
    target: Tensor,
    num_outputs: int = 1,
) -> Tensor: ...

def r2_score(
    preds: Tensor,
    target: Tensor,
    num_outputs: int = 1,
    multioutput: str = "uniform_average",
    adjusted: int = 0,
) -> Tensor: ...

def pearson_corrcoef(
    preds: Tensor,
    target: Tensor,
    num_outputs: int = 1,
) -> Tensor: ...

def spearman_corrcoef(
    preds: Tensor,
    target: Tensor,
    num_outputs: int = 1,
) -> Tensor: ...

def cosine_similarity(
    preds: Tensor,
    target: Tensor,
    reduction: str = "sum",
) -> Tensor: ...
```
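
As a rough guide to the `squared` flag on `mean_squared_error`, the sketch below computes the same quantities in plain Python. It assumes that `squared=False` yields the root of the mean squared error (RMSE); `mse_sketch` and `mae_sketch` are hypothetical helpers for illustration, not the library's implementation.

```python
import math
from typing import Sequence


def mse_sketch(preds: Sequence[float], target: Sequence[float], squared: bool = True) -> float:
    """Mean squared error; returns its square root (RMSE) when squared=False."""
    mean_sq = sum((p - t) ** 2 for p, t in zip(preds, target)) / len(target)
    return mean_sq if squared else math.sqrt(mean_sq)


def mae_sketch(preds: Sequence[float], target: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(preds, target)) / len(target)


preds = [2.0, 4.0, 6.0]
target = [1.0, 4.0, 8.0]  # errors: 1, 0, -2

mse = mse_sketch(preds, target)                   # (1 + 0 + 4) / 3
rmse = mse_sketch(preds, target, squared=False)   # sqrt of the value above
mae = mae_sketch(preds, target)                   # (1 + 0 + 2) / 3

print(f"MSE: {mse:.4f}, RMSE: {rmse:.4f}, MAE: {mae:.4f}")
```

RMSE is in the same units as the target, which often makes it easier to compare against MAE than the squared value.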

### Audio Functions

Functional audio quality and separation metrics.

```python { .api }
def scale_invariant_signal_distortion_ratio(
    preds: Tensor,
    target: Tensor,
    zero_mean: bool = True,
) -> Tensor: ...

def signal_distortion_ratio(
    preds: Tensor,
    target: Tensor,
    use_cg_iter: Optional[int] = None,
    filter_length: int = 512,
    zero_mean: bool = True,
    load_diag: Optional[float] = None,
) -> Tensor: ...

def permutation_invariant_training(
    preds: Tensor,
    target: Tensor,
    metric: Callable,
    mode: str = "speaker-wise",
    eval_func: str = "max",
) -> Tensor: ...

def perceptual_evaluation_speech_quality(
    preds: Tensor,
    target: Tensor,
    fs: int,
    mode: str = "wb",
) -> Tensor: ...
```

### Image Functions

Functional image quality assessment metrics.

```python { .api }
def peak_signal_noise_ratio(
    preds: Tensor,
    target: Tensor,
    data_range: Optional[float] = None,
    base: float = 10.0,
    reduction: str = "elementwise_mean",
) -> Tensor: ...

def structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Tuple[float, float]] = 1.5,
    kernel_size: Union[int, Tuple[int, int]] = 11,
    reduction: str = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
) -> Tensor: ...

def multiscale_structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Tuple[float, float]] = 1.5,
    kernel_size: Union[int, Tuple[int, int]] = 11,
    reduction: str = "elementwise_mean",
    data_range: Optional[float] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
    normalize: Optional[str] = "relu",
) -> Tensor: ...

def universal_image_quality_index(
    preds: Tensor,
    target: Tensor,
    kernel_size: Union[int, Tuple[int, int]] = 8,
    sigma: Union[float, Tuple[float, float]] = 1.5,
    reduction: str = "elementwise_mean",
) -> Tensor: ...
```
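
To show how `data_range` and `base` enter the PSNR calculation, here is a plain-Python sketch of the standard formula, PSNR = 10 · log_base(data_range² / MSE). The helper `psnr_sketch` is a hypothetical name for illustration; it requires an explicit `data_range` and operates on flat lists rather than image tensors.

```python
import math
from typing import Sequence


def psnr_sketch(
    preds: Sequence[float],
    target: Sequence[float],
    data_range: float,
    base: float = 10.0,
) -> float:
    """Peak signal-to-noise ratio from the mean squared error of flat pixel lists."""
    mse = sum((p - t) ** 2 for p, t in zip(preds, target)) / len(target)
    return 10.0 * math.log(data_range ** 2 / mse, base)


# Pixels in [0, 1]; every prediction is off by 0.1, so the MSE is about 0.01.
preds = [0.5, 0.5, 0.5, 0.5]
target = [0.6, 0.4, 0.6, 0.4]

psnr_db = psnr_sketch(preds, target, data_range=1.0)
print(f"PSNR: {psnr_db:.2f} dB")  # roughly 20 dB for MSE = 0.01 and data_range = 1
```

Because the result scales with `data_range`, comparing PSNR values is only meaningful when the same range (for example 1.0 or 255) is used throughout.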

### Text Functions

Functional NLP metrics for text evaluation.

```python { .api }
def bleu_score(
    preds: Sequence[str],
    target: Sequence[Sequence[str]],
    n_gram: int = 4,
    smooth: bool = False,
    weights: Optional[Sequence[float]] = None,
) -> Tensor: ...

def rouge_score(
    preds: Union[str, Sequence[str]],
    target: Union[str, Sequence[str], Sequence[Sequence[str]]],
    rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL"),
    use_stemmer: bool = False,
    normalizer: Optional[Callable[[str], str]] = None,
    tokenizer: Optional[Callable[[str], Sequence[str]]] = None,
    accumulate: str = "best",
) -> Dict[str, Tensor]: ...

def word_error_rate(
    preds: Union[str, List[str]],
    target: Union[str, List[str]],
) -> Tensor: ...

def char_error_rate(
    preds: Union[str, List[str]],
    target: Union[str, List[str]],
) -> Tensor: ...

def edit_distance(
    preds: Union[str, List[str]],
    target: Union[str, List[str]],
    substitution_cost: int = 1,
    reduction: Optional[str] = "mean",
) -> Tensor: ...
```
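
Word and character error rates are both built on edit distance. The sketch below works through the standard definition (edit distance divided by reference length) in plain Python. The helper `levenshtein` and the variable names are hypothetical; counting spaces as characters in the character error rate is an assumption of this sketch.

```python
from typing import Sequence


def levenshtein(pred: Sequence, ref: Sequence) -> int:
    """Minimum number of insertions, deletions, and substitutions turning pred into ref."""
    prev = list(range(len(ref) + 1))  # distances from an empty prediction
    for i, p in enumerate(pred, start=1):
        curr = [i]
        for j, r in enumerate(ref, start=1):
            cost = 0 if p == r else 1
            curr.append(min(prev[j] + 1,          # delete p
                            curr[j - 1] + 1,      # insert r
                            prev[j - 1] + cost))  # substitute or match
        prev = curr
    return prev[-1]


pred_text = "this is a test"
target_text = "this is the test"

# Word level: one substitution ("a" -> "the") over four reference words.
wer = levenshtein(pred_text.split(), target_text.split()) / len(target_text.split())

# Character level: three edits over sixteen reference characters (spaces included).
cer = levenshtein(pred_text, target_text) / len(target_text)

print(f"WER: {wer:.4f}, CER: {cer:.4f}")
```

Both rates are normalized by the reference length, so they can exceed 1.0 when the prediction is much longer than the reference.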

### Clustering Functions

Functional clustering evaluation metrics.

```python { .api }
def adjusted_rand_score(
    preds: Tensor,
    target: Tensor,
) -> Tensor: ...

def normalized_mutual_info_score(
    preds: Tensor,
    target: Tensor,
    average: str = "arithmetic",
) -> Tensor: ...

def calinski_harabasz_score(
    data: Tensor,
    labels: Tensor,
) -> Tensor: ...

def davies_bouldin_score(
    data: Tensor,
    labels: Tensor,
) -> Tensor: ...
```

### Pairwise Functions

Functional pairwise distance and similarity measures.

```python { .api }
def pairwise_cosine_similarity(
    x: Tensor,
    y: Optional[Tensor] = None,
    reduction: Optional[str] = None,
    zero_diagonal: bool = True,
) -> Tensor: ...

def pairwise_euclidean_distance(
    x: Tensor,
    y: Optional[Tensor] = None,
    reduction: Optional[str] = None,
    zero_diagonal: bool = True,
) -> Tensor: ...

def pairwise_manhattan_distance(
    x: Tensor,
    y: Optional[Tensor] = None,
    reduction: Optional[str] = None,
    zero_diagonal: bool = True,
) -> Tensor: ...

def pairwise_minkowski_distance(
    x: Tensor,
    y: Optional[Tensor] = None,
    p: float = 2.0,
    reduction: Optional[str] = None,
    zero_diagonal: bool = True,
) -> Tensor: ...
```
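
The pairwise functions take `N` rows in `x` and `M` rows in `y` and produce an `N × M` matrix. The plain-Python sketch below illustrates that layout for Euclidean distance. It assumes that omitting `y` means comparing `x` against itself; `pairwise_euclidean_sketch` is a hypothetical helper that ignores the `reduction` and `zero_diagonal` options.

```python
import math
from typing import List, Optional, Sequence


def pairwise_euclidean_sketch(
    x: Sequence[Sequence[float]],
    y: Optional[Sequence[Sequence[float]]] = None,
) -> List[List[float]]:
    """Entry [i][j] is the Euclidean distance between x[i] and y[j]."""
    y = x if y is None else y  # assumption: no y means x versus itself
    return [[math.dist(a, b) for b in y] for a in x]


x = [[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]]  # N = 3 rows
y = [[0.0, 0.0], [6.0, 8.0]]              # M = 2 rows

dist = pairwise_euclidean_sketch(x, y)    # 3 rows, 2 columns
self_dist = pairwise_euclidean_sketch(x)  # 3 rows, 3 columns, zeros on the diagonal

print(len(dist), len(dist[0]))
print(len(self_dist), len(self_dist[0]))
```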

## Usage Examples

### Basic Functional Usage

```python
import torch
import torchmetrics.functional as F

# Binary classification
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 0])

# Compute metrics directly
acc = F.accuracy(preds, target, task="binary")
f1 = F.f1_score(preds, target, task="binary")
auc = F.auroc(preds, target, task="binary")

print(f"Accuracy: {acc:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"AUROC: {auc:.4f}")
```

### Multiclass Classification

```python
import torch
import torchmetrics.functional.classification as FC

# Multiclass predictions
preds = torch.randn(100, 5).softmax(dim=-1)
target = torch.randint(0, 5, (100,))

# Compute various metrics
acc = FC.multiclass_accuracy(preds, target, num_classes=5)
precision = FC.multiclass_precision(preds, target, num_classes=5, average="macro")
recall = FC.multiclass_recall(preds, target, num_classes=5, average="macro")
conf_matrix = FC.multiclass_confusion_matrix(preds, target, num_classes=5)

print(f"Accuracy: {acc:.4f}")
print(f"Macro Precision: {precision:.4f}")
print(f"Macro Recall: {recall:.4f}")
print(f"Confusion Matrix Shape: {conf_matrix.shape}")
```

### Regression Metrics

```python
import torch
import torchmetrics.functional.regression as FR

# Regression predictions
preds = torch.randn(50, 1)
target = torch.randn(50, 1)

# Compute regression metrics
mse = FR.mean_squared_error(preds, target)
mae = FR.mean_absolute_error(preds, target)
r2 = FR.r2_score(preds, target)
pearson = FR.pearson_corrcoef(preds, target)

print(f"MSE: {mse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R²: {r2:.4f}")
print(f"Pearson Correlation: {pearson:.4f}")
```

### Image Quality Assessment

```python
import torch
import torchmetrics.functional.image as FI

# Image tensors
preds = torch.rand(4, 3, 256, 256)
target = torch.rand(4, 3, 256, 256)

# Compute image quality metrics
psnr = FI.peak_signal_noise_ratio(preds, target)
ssim = FI.structural_similarity_index_measure(preds, target)
ms_ssim = FI.multiscale_structural_similarity_index_measure(preds, target)

print(f"PSNR: {psnr:.4f}")
print(f"SSIM: {ssim:.4f}")
print(f"MS-SSIM: {ms_ssim:.4f}")
```

### Text Evaluation

```python
import torchmetrics.functional.text as FT

# Text evaluation
preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]

# Compute text metrics
bleu = FT.bleu_score(preds, target)
rouge_scores = FT.rouge_score(preds[0], target[0])

print(f"BLEU Score: {bleu:.4f}")
print(f"ROUGE-1 F1: {rouge_scores['rouge1_fmeasure']:.4f}")
print(f"ROUGE-L F1: {rouge_scores['rougeL_fmeasure']:.4f}")

# Error rates
pred_text = ["this is a test"]
target_text = ["this is the test"]
wer = FT.word_error_rate(pred_text, target_text)
cer = FT.char_error_rate(pred_text, target_text)

print(f"Word Error Rate: {wer:.4f}")
print(f"Character Error Rate: {cer:.4f}")
```

### Audio Quality

```python
import torch
import torchmetrics.functional.audio as FA

# Audio signals
preds = torch.randn(4, 8000)  # 4 samples, 8000 time steps
target = torch.randn(4, 8000)

# Compute audio metrics
# The functional versions return one value per sample (shape (4,) here),
# so average over the batch before formatting a single number.
si_sdr = FA.scale_invariant_signal_distortion_ratio(preds, target).mean()
si_snr = FA.scale_invariant_signal_noise_ratio(preds, target).mean()

print(f"SI-SDR: {si_sdr:.4f} dB")
print(f"SI-SNR: {si_snr:.4f} dB")
```

### Pairwise Distances

```python
import torch
import torchmetrics.functional.pairwise as FP

# Feature vectors
x = torch.randn(100, 64)  # 100 samples, 64-dim features
y = torch.randn(50, 64)  # 50 samples, 64-dim features

# Compute pairwise similarities and distances
cosine_sim = FP.pairwise_cosine_similarity(x, y)
euclidean_dist = FP.pairwise_euclidean_distance(x, y)
manhattan_dist = FP.pairwise_manhattan_distance(x, y)

print(f"Cosine Similarity Shape: {cosine_sim.shape}")  # (100, 50)
print(f"Euclidean Distance Shape: {euclidean_dist.shape}")  # (100, 50)
print(f"Manhattan Distance Shape: {manhattan_dist.shape}")  # (100, 50)
```

### Clustering Evaluation

```python
import torch
import torchmetrics.functional.clustering as FCL

# Clustering results
pred_clusters = torch.randint(0, 3, (100,))
true_clusters = torch.randint(0, 3, (100,))

# Clustering metrics
ari = FCL.adjusted_rand_score(pred_clusters, true_clusters)
nmi = FCL.normalized_mutual_info_score(pred_clusters, true_clusters)

print(f"Adjusted Rand Index: {ari:.4f}")
print(f"Normalized Mutual Info: {nmi:.4f}")

# Internal clustering metrics (require data)
data = torch.randn(100, 10)  # 100 samples, 10 features
ch_score = FCL.calinski_harabasz_score(data, pred_clusters)
db_score = FCL.davies_bouldin_score(data, pred_clusters)

print(f"Calinski-Harabasz Score: {ch_score:.4f}")
print(f"Davies-Bouldin Score: {db_score:.4f}")
```

## Functional vs Class-based API

### When to Use Functional API

- One-off metric computations
- Custom training loops without Lightning
- Minimal memory overhead requirements
- Integration with existing codebases
- Research experiments requiring flexibility

### When to Use Class-based API

- Accumulating metrics across batches
- Distributed training scenarios
- PyTorch Lightning integration
- Automatic state management needed
- Complex metric tracking workflows
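
The trade-off between the two lists above can be illustrated without any dependencies. In this sketch, `accuracy_fn` stands in for a stateless functional metric and `RunningAccuracy` for a stateful class-based one; both are hypothetical names written for this example, not TorchMetrics APIs. It shows why averaging per-batch results is not the same as accumulating across batches when batch sizes differ.

```python
from typing import List, Sequence, Tuple


def accuracy_fn(preds: Sequence[int], target: Sequence[int]) -> float:
    """Stateless: the result depends only on the arguments of this call."""
    correct = sum(p == t for p, t in zip(preds, target))
    return correct / len(target)


class RunningAccuracy:
    """Stateful: counts persist between update() calls until compute()."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total


# Two batches of unequal size: (preds, target)
batches: List[Tuple[List[int], List[int]]] = [
    ([1, 0, 1, 1], [1, 0, 0, 1]),  # 3 of 4 correct
    ([0, 0], [1, 1]),              # 0 of 2 correct
]

# Functional style: one independent result per batch.
per_batch = [accuracy_fn(p, t) for p, t in batches]
mean_of_batches = sum(per_batch) / len(per_batch)

# Class-based style: accumulate counts, then compute once.
running = RunningAccuracy()
for p, t in batches:
    running.update(p, t)
overall = running.compute()

print(per_batch)        # [0.75, 0.0]
print(mean_of_batches)  # 0.375
print(overall)          # 0.5 (3 correct out of 6)
```

The naive mean of per-batch values weights the small batch as heavily as the large one, which is one reason accumulation across batches is listed under the class-based API.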

## Types

```python { .api }
from typing import Union, Optional, List, Dict, Tuple, Sequence, Callable, Any, Literal
import torch
from torch import Tensor

# Common functional types
FunctionalOutput = Union[Tensor, Dict[str, Tensor], Tuple[Tensor, ...]]
# String options are expressed with Literal; Union of bare strings is not a valid type
TaskType = Literal["binary", "multiclass", "multilabel"]
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
ReductionType = Literal["mean", "sum", "none", "elementwise_mean"]
```