# Classification Metrics

Comprehensive classification metrics supporting binary, multiclass, and multilabel scenarios. All classification metrics share a consistent API across task types: each is available as a wrapper class that dispatches on its `task` argument, plus dedicated Binary, Multiclass, and Multilabel variants.

## Capabilities

### Accuracy Metrics

Measures the proportion of correct predictions among all predictions made.

```python { .api }
class Accuracy(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryAccuracy(Metric):
    def __init__(
        self,
        threshold: float = 0.5,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassAccuracy(Metric):
    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "micro",
        top_k: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelAccuracy(Metric):
    def __init__(
        self,
        num_labels: int,
        threshold: float = 0.5,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
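
For instance, top-k accuracy via the multiclass variant. A minimal sketch with random placeholder data, assuming the variant is importable from `torchmetrics.classification` as in upstream torchmetrics:

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

# Top-2 accuracy: a prediction counts as correct if the target class
# is among the two highest-scoring classes
acc = MulticlassAccuracy(num_classes=4, top_k=2, average="micro")
preds = torch.randn(8, 4).softmax(dim=-1)   # (N, C) class probabilities
target = torch.randint(0, 4, (8,))          # (N,) integer labels
print(acc(preds, target))                   # scalar tensor in [0, 1]
```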

### Area Under ROC Curve (AUROC)

Computes the Area Under the Receiver Operating Characteristic Curve, measuring the model's ability to distinguish between classes.

```python { .api }
class AUROC(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryAUROC(Metric):
    def __init__(
        self,
        max_fpr: Optional[float] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassAUROC(Metric):
    def __init__(
        self,
        num_classes: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelAUROC(Metric):
    def __init__(
        self,
        num_labels: int,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
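
A small deterministic sketch of the binary variant (import path from `torchmetrics.classification` assumed, as above):

```python
import torch
from torchmetrics.classification import BinaryAUROC

auroc = BinaryAUROC()
preds = torch.tensor([0.2, 0.8, 0.6, 0.3])   # predicted probabilities
target = torch.tensor([0, 1, 1, 0])          # binary ground truth
# All positives are ranked above all negatives, so AUROC = 1.0;
# 0.5 would indicate chance-level separation
print(auroc(preds, target))
```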

### ROC Curves

Computes Receiver Operating Characteristic curves for visualization and analysis.

```python { .api }
class ROC(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class BinaryROC(Metric):
    def __init__(
        self,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MulticlassROC(Metric):
    def __init__(
        self,
        num_classes: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelROC(Metric):
    def __init__(
        self,
        num_labels: int,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
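
A minimal sketch of the multiclass variant with random placeholder data (assuming, as in upstream torchmetrics, that with `thresholds=None` the curves come back as per-class lists):

```python
import torch
from torchmetrics.classification import MulticlassROC

roc = MulticlassROC(num_classes=3)
preds = torch.randn(20, 3).softmax(dim=-1)
target = torch.randint(0, 3, (20,))
# With thresholds=None, each return value is a list with one tensor per class
fpr, tpr, thresholds = roc(preds, target)
```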

### Precision and Recall

Precision measures the proportion of positive predictions that are correct; recall measures the proportion of actual positives that are correctly predicted.

```python { .api }
class Precision(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class Recall(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

Each precision and recall metric also has Binary, Multiclass, and Multilabel variants with task-specific parameters.
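
A small hand-worked sketch of the binary task wrappers; with the default 0.5 threshold the probabilistic predictions below become `[0, 1, 1, 0]`:

```python
import torch
from torchmetrics import Precision, Recall

precision = Precision(task="binary")
recall = Recall(task="binary")
preds = torch.tensor([0.2, 0.7, 0.6, 0.1])   # thresholded at 0.5 -> [0, 1, 1, 0]
target = torch.tensor([0, 1, 0, 0])
print(precision(preds, target))  # 1 TP / (1 TP + 1 FP) = 0.5
print(recall(preds, target))     # 1 TP / (1 TP + 0 FN) = 1.0
```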

### F-Scores

Harmonic mean of precision and recall; F1 (beta = 1) is the most commonly used. Larger beta values weight recall more heavily than precision.

```python { .api }
class F1Score(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class FBetaScore(Metric):
    def __init__(
        self,
        task: str,
        beta: float = 1.0,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
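
A brief deterministic sketch contrasting F1 with F2 (`beta=2.0` weights recall more than precision):

```python
import torch
from torchmetrics import F1Score, FBetaScore

f1 = F1Score(task="binary")
f2 = FBetaScore(task="binary", beta=2.0)  # beta > 1 favors recall
preds = torch.tensor([0.9, 0.8, 0.7, 0.2])  # thresholded at 0.5 -> [1, 1, 1, 0]
target = torch.tensor([1, 0, 0, 1])
# precision = 1/3, recall = 1/2: F1 = 0.40, F2 ~= 0.455 (pulled toward recall)
print(f1(preds, target), f2(preds, target))
```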

### Average Precision

Computes the average precision score, which summarizes the precision-recall curve as the weighted mean of precisions achieved at each threshold, using the increase in recall from the previous threshold as the weight.

```python { .api }
class AveragePrecision(Metric):
    def __init__(
        self,
        task: str,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "macro",
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
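
A small deterministic sketch of the binary task wrapper:

```python
import torch
from torchmetrics import AveragePrecision

ap = AveragePrecision(task="binary")
preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
# Positives at ranks 1 and 3 of the score ordering:
# AP = 0.5 * 1.0 + 0.5 * (2/3) ~= 0.83
print(ap(preds, target))
```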

### Confusion Matrix

Computes the confusion matrix for evaluating classification performance, with a detailed breakdown of true/false positives and negatives.

```python { .api }
class ConfusionMatrix(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        normalize: Optional[str] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
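
For example, `normalize="true"` divides each row by the number of true samples per class. A minimal sketch with random placeholder data:

```python
import torch
from torchmetrics import ConfusionMatrix

cm = ConfusionMatrix(task="multiclass", num_classes=3, normalize="true")
preds = torch.randint(0, 3, (50,))    # hard class predictions
target = torch.randint(0, 3, (50,))
print(cm(preds, target))              # (3, 3) matrix; each row sums to 1
```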

### Statistical Scores

Computes true positives, false positives, true negatives, false negatives, and support statistics.

```python { .api }
class StatScores(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        multidim_average: str = "global",
        top_k: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
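
For the binary task, the returned tensor packs the statistics in the order `[tp, fp, tn, fn, support]`. A small deterministic sketch:

```python
import torch
from torchmetrics import StatScores

stats = StatScores(task="binary")
preds = torch.tensor([0.9, 0.1, 0.8, 0.4])   # thresholded at 0.5 -> [1, 0, 1, 0]
target = torch.tensor([1, 0, 0, 1])
print(stats(preds, target))  # tensor([tp, fp, tn, fn, support]) = [1, 1, 1, 1, 2]
```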

### Threshold-Based Metrics

Metrics that find optimal thresholds or evaluate performance at specific operating points.

```python { .api }
class PrecisionAtFixedRecall(Metric):
    def __init__(
        self,
        task: str,
        min_recall: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class RecallAtFixedPrecision(Metric):
    def __init__(
        self,
        task: str,
        min_precision: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class SensitivityAtSpecificity(Metric):
    def __init__(
        self,
        task: str,
        min_specificity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class SpecificityAtSensitivity(Metric):
    def __init__(
        self,
        task: str,
        min_sensitivity: float,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```

### Advanced Classification Metrics

Specialized metrics for specific classification scenarios.

```python { .api }
class CohenKappa(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        weights: Optional[str] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MatthewsCorrCoef(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class JaccardIndex(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        average: Optional[str] = "micro",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class HammingDistance(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class ExactMatch(Metric):
    def __init__(
        self,
        task: str,
        threshold: float = 0.5,
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        multidim_average: str = "global",
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...
```
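
A small deterministic sketch of the two chance-corrected agreement metrics:

```python
import torch
from torchmetrics import CohenKappa, MatthewsCorrCoef

kappa = CohenKappa(task="binary")
mcc = MatthewsCorrCoef(task="binary")
preds = torch.tensor([0.9, 0.2, 0.7, 0.3, 0.6])   # thresholded -> [1, 0, 1, 0, 1]
target = torch.tensor([1, 0, 1, 1, 0])
# Both score 1.0 for perfect agreement and ~0.0 for chance-level agreement
print(kappa(preds, target), mcc(preds, target))
```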

### Calibration and Ranking Metrics

Metrics for evaluating model calibration and ranking quality.

```python { .api }
class CalibrationError(Metric):
    def __init__(
        self,
        task: str,
        n_bins: int = 15,
        norm: str = "l1",
        num_classes: Optional[int] = None,
        num_labels: Optional[int] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelRankingAveragePrecision(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelRankingLoss(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...

class MultilabelCoverageError(Metric):
    def __init__(
        self,
        num_labels: int,
        validate_args: bool = True,
        **kwargs
    ): ...
```
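
A minimal sketch of expected calibration error with random placeholder data (with `norm="l1"` the metric averages |confidence - accuracy| over the confidence bins):

```python
import torch
from torchmetrics import CalibrationError

ece = CalibrationError(task="binary", n_bins=10, norm="l1")
preds = torch.rand(200)                 # predicted probabilities
target = torch.randint(0, 2, (200,))
print(ece(preds, target))               # 0.0 = perfectly calibrated
```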

## Usage Examples

### Basic Classification

```python
import torch
from torchmetrics import Accuracy, F1Score, ConfusionMatrix

# Binary classification
binary_acc = Accuracy(task="binary")
preds = torch.tensor([0.1, 0.9, 0.8, 0.4])
target = torch.tensor([0, 1, 1, 0])
print(binary_acc(preds, target))

# Multiclass classification
multiclass_f1 = F1Score(task="multiclass", num_classes=3, average="macro")
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(0, 3, (10,))
print(multiclass_f1(preds, target))

# Multilabel classification
multilabel_cm = ConfusionMatrix(task="multilabel", num_labels=3)
preds = torch.randn(10, 3).sigmoid()
target = torch.randint(0, 2, (10, 3))
print(multilabel_cm(preds, target))
```

### Threshold-Based Metrics

```python
import torch
from torchmetrics import PrecisionAtFixedRecall, ROC

# Find precision at 90% recall
precision_at_recall = PrecisionAtFixedRecall(task="binary", min_recall=0.9)
preds = torch.randn(100).sigmoid()
target = torch.randint(0, 2, (100,))
precision_value, threshold = precision_at_recall(preds, target)
print(f"Precision: {precision_value:.3f} at threshold: {threshold:.3f}")

# Compute ROC curve
roc = ROC(task="binary")
fpr, tpr, thresholds = roc(preds, target)
```

## Types

```python { .api }
TaskType = Literal["binary", "multiclass", "multilabel"]
AverageType = Optional[Literal["micro", "macro", "weighted", "none"]]
MDMCAverageType = Literal["global", "samplewise"]
ThresholdType = Union[float, List[float], Tensor]
```