# TorchMetrics

A comprehensive metrics library for PyTorch and PyTorch Lightning, providing 400+ rigorously tested metrics across classification, regression, audio, image, text, and other machine learning domains. TorchMetrics offers distributed and scalable metric computation with consistent APIs, automatic device handling, and seamless integration with PyTorch workflows.

## Package Information

- **Package Name**: torchmetrics
- **Package Type**: Library
- **Language**: Python
- **Installation**: `pip install torchmetrics`

## Core Imports

```python
import torchmetrics
```

For functional API:

```python
import torchmetrics.functional as F
```

For specific metrics:

```python
from torchmetrics import Accuracy, AUROC, F1Score
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
from torchmetrics.regression import MeanSquaredError, R2Score
```

## Basic Usage

```python
import torch
import torchmetrics

# Initialize metrics
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
f1 = torchmetrics.F1Score(task="multiclass", num_classes=3)

# Create sample predictions and targets
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))

# Compute metrics
acc_score = accuracy(preds, target)
f1_score = f1(preds, target)

print(f"Accuracy: {acc_score:.4f}")
print(f"F1 Score: {f1_score:.4f}")

# Using functional API (aliased so the metric objects above are not shadowed)
from torchmetrics.functional import accuracy as accuracy_func, f1_score as f1_func

acc_functional = accuracy_func(preds, target, task="multiclass", num_classes=3)
f1_functional = f1_func(preds, target, task="multiclass", num_classes=3)

print(f"Functional Accuracy: {acc_functional:.4f}")
print(f"Functional F1: {f1_functional:.4f}")
```

## Architecture

TorchMetrics follows a dual-interface design, with collections and wrappers built on top of it:

- **Modular (Class-based) Interface**: Stateful metric classes that accumulate values across batches and provide automatic synchronization in distributed settings
- **Functional Interface**: Stateless functions for single-batch computations and one-off metric calculations
- **MetricCollection**: Container for organizing and computing multiple metrics simultaneously
- **Wrappers**: Advanced functionality for bootstrapping, per-class computation, and multi-task scenarios

All metrics inherit from the base `Metric` class, ensuring consistent behavior, automatic device handling, state management, and distributed computation support across the entire library.

## Capabilities

### Classification Metrics

Comprehensive classification metrics supporting binary, multiclass, and multilabel scenarios. Includes accuracy, precision, recall, F-scores, ROC/AUC, confusion matrices, and threshold-based metrics.

```python { .api }
class Accuracy(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class AUROC(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class F1Score(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class Precision(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class Recall(Metric):
    def __init__(self, task: str, num_classes: Optional[int] = None, **kwargs): ...

class ConfusionMatrix(Metric):
    def __init__(self, task: str, num_classes: int, **kwargs): ...
```

[Classification](./classification.md)

### Regression Metrics

Metrics for regression tasks including error measurements, correlation coefficients, and explained variance measures for continuous target prediction evaluation.

```python { .api }
class MeanSquaredError(Metric):
    def __init__(self, **kwargs): ...

class MeanAbsoluteError(Metric):
    def __init__(self, **kwargs): ...

class R2Score(Metric):
    def __init__(self, num_outputs: int = 1, **kwargs): ...

class PearsonCorrCoef(Metric):
    def __init__(self, num_outputs: int = 1, **kwargs): ...
```

[Regression](./regression.md)

### Audio Metrics

Specialized metrics for audio processing and speech evaluation including signal-to-noise ratios, perceptual quality measures, and separation metrics.

```python { .api }
class ScaleInvariantSignalDistortionRatio(Metric):
    def __init__(self, **kwargs): ...

class PermutationInvariantTraining(Metric):
    def __init__(self, metric_func: Callable, mode: str = "speaker-wise", **kwargs): ...

class PerceptualEvaluationSpeechQuality(Metric):
    def __init__(self, fs: int, mode: str, **kwargs): ...
```

[Audio](./audio.md)

### Image Quality Metrics

Image quality assessment metrics including structural similarity, peak signal-to-noise ratio, and perceptual quality measures for computer vision applications.

```python { .api }
class StructuralSimilarityIndexMeasure(Metric):
    def __init__(self, **kwargs): ...

class PeakSignalNoiseRatio(Metric):
    def __init__(self, **kwargs): ...

class FrechetInceptionDistance(Metric):
    def __init__(self, feature: int = 2048, **kwargs): ...
```

[Image](./image.md)

### Text Metrics

Natural language processing metrics for translation, summarization, and text generation evaluation including BLEU, ROUGE, and semantic similarity measures.

```python { .api }
class BLEUScore(Metric):
    def __init__(self, n_gram: int = 4, **kwargs): ...

class ROUGEScore(Metric):
    def __init__(self, rouge_keys: Union[str, Tuple[str, ...]] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), **kwargs): ...

class BERTScore(Metric):
    def __init__(self, model_name_or_path: Optional[str] = None, **kwargs): ...
```

[Text](./text.md)

### Detection Metrics

Object detection and instance segmentation metrics for evaluating bounding box predictions, IoU calculations, and mean average precision.

```python { .api }
class MeanAveragePrecision(Metric):
    def __init__(self, **kwargs): ...

class IntersectionOverUnion(Metric):
    def __init__(self, **kwargs): ...

class PanopticQuality(Metric):
    def __init__(self, **kwargs): ...
```

[Detection](./detection.md)

### Clustering Metrics

Unsupervised learning evaluation metrics including mutual information, Rand indices, and silhouette analysis for cluster quality assessment.

```python { .api }
class AdjustedRandScore(Metric):
    def __init__(self, **kwargs): ...

class NormalizedMutualInfoScore(Metric):
    def __init__(self, average_method: str = "arithmetic", **kwargs): ...

class CalinskiHarabaszScore(Metric):
    def __init__(self, **kwargs): ...
```

[Clustering](./clustering.md)

### Information Retrieval Metrics

Metrics for ranking and retrieval systems including precision at k, mean average precision, and normalized discounted cumulative gain.

```python { .api }
class RetrievalMAP(Metric):
    def __init__(self, **kwargs): ...

class RetrievalNormalizedDCG(Metric):
    def __init__(self, top_k: Optional[int] = None, **kwargs): ...

class RetrievalMRR(Metric):
    def __init__(self, **kwargs): ...
```

[Retrieval](./retrieval.md)

### Segmentation Metrics

Semantic and instance segmentation evaluation including Dice coefficients, Intersection over Union, and Hausdorff distance for pixel-level predictions.

```python { .api }
class DiceScore(Metric):
    def __init__(self, **kwargs): ...

class MeanIoU(Metric):
    def __init__(self, num_classes: int, **kwargs): ...

class HausdorffDistance(Metric):
    def __init__(self, **kwargs): ...
```

[Segmentation](./segmentation.md)

### Multimodal Metrics

Metrics for evaluating multimodal AI systems including video-audio synchronization and cross-modal quality assessment.

```python { .api }
class LipVertexError(Metric):
    def __init__(self, **kwargs): ...

class CLIPScore(Metric):
    def __init__(self, model_name_or_path: str = "openai/clip-vit-large-patch14", **kwargs): ...
```

[Multimodal](./multimodal.md)

### Nominal/Categorical Metrics

Statistical measures for analyzing associations and agreements between categorical variables.

```python { .api }
class CramersV(Metric):
    def __init__(self, num_classes: int, **kwargs): ...

class FleissKappa(Metric):
    def __init__(self, mode: str = "counts", **kwargs): ...
```

[Nominal](./nominal.md)

### Shape Metrics

Metrics for analyzing geometric shapes and spatial configurations.

```python { .api }
class ProcrustesDisparity(Metric):
    def __init__(self, **kwargs): ...
```

[Shape](./shape.md)

### Video Metrics

Specialized metrics for video quality assessment and evaluation.

```python { .api }
class VideoMultiMethodAssessmentFusion(Metric):
    def __init__(self, **kwargs): ...
```

[Video](./video.md)

### Aggregation and Utilities

Core utilities for metric aggregation, distributed computation, and advanced metric composition including running statistics and bootstrapping.

```python { .api }
class MetricCollection:
    def __init__(self, metrics: Union[Dict[str, Metric], List[Metric]], **kwargs): ...

class MeanMetric(Metric):
    def __init__(self, **kwargs): ...

class SumMetric(Metric):
    def __init__(self, **kwargs): ...

class BootStrapper:
    def __init__(self, base_metric: Metric, num_bootstraps: int = 10, **kwargs): ...
```

[Utilities](./utilities.md)

### Functional API

Direct functional implementations of most metrics for stateless computation, providing immediate results without object instantiation or state management.

```python { .api }
def accuracy(preds: Tensor, target: Tensor, task: str, **kwargs) -> Tensor: ...
def f1_score(preds: Tensor, target: Tensor, task: str, **kwargs) -> Tensor: ...
def mean_squared_error(preds: Tensor, target: Tensor) -> Tensor: ...
def structural_similarity_index_measure(preds: Tensor, target: Tensor, **kwargs) -> Tensor: ...
```

[Functional](./functional.md)

## Types

Core imports for TorchMetrics:

```python
from typing import Union, Optional, Tuple, Dict, List, Any, Callable, Literal
import torch
from torch import Tensor
```

Common type aliases:

```python { .api }
TaskType = Literal["binary", "multiclass", "multilabel"]
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
MDMCAverageType = Literal["global", "samplewise"]
ThresholdType = Union[float, List[float], Tensor]
```

Base metric class:

```python { .api }
class Metric:
    """Base class for all metrics."""
    def __init__(self, **kwargs): ...
    def __call__(self, *args, **kwargs) -> Any: ...
    def update(self, *args, **kwargs) -> None: ...
    def compute(self) -> Any: ...
    def reset(self) -> None: ...
    def to(self, device: Union[str, torch.device]) -> "Metric": ...
    def forward(self, *args, **kwargs) -> Any: ...
```