or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

audio.mdclassification.mdclustering.mddetection.mdfunctional.mdimage.mdindex.mdmultimodal.mdnominal.mdregression.mdretrieval.mdsegmentation.mdshape.mdtext.mdutilities.mdvideo.md

utilities.mddocs/

0

# Utilities and Aggregation

1

2

Core utilities for metric aggregation, distributed computation, and advanced metric composition including running statistics, bootstrapping, and metric collections for comprehensive evaluation workflows.

3

4

## Capabilities

5

6

### Metric Collection

7

8

Container for organizing and computing multiple metrics simultaneously with automatic synchronization and state management.

9

10

```python { .api }

11

class MetricCollection:

12

def __init__(

13

self,

14

metrics: Union[Dict[str, Metric], List[Metric], Tuple[Metric, ...]],

15

prefix: Optional[str] = None,

16

postfix: Optional[str] = None,

17

compute_groups: Union[bool, List[List[str]]] = True,

18

**kwargs

19

): ...

20

21

def __call__(self, *args, **kwargs) -> Dict[str, Any]: ...

22

def update(self, *args, **kwargs) -> None: ...

23

def compute(self) -> Dict[str, Any]: ...

24

def reset(self) -> None: ...

25

def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": ...

26

```

27

28

### Aggregation Metrics

29

30

Basic metrics for accumulating and aggregating values across batches and distributed processes.

31

32

```python { .api }

33

class MeanMetric(Metric):

34

def __init__(

35

self,

36

nan_strategy: str = "warn",

37

**kwargs

38

): ...

39

40

class SumMetric(Metric):

41

def __init__(

42

self,

43

nan_strategy: str = "warn",

44

**kwargs

45

): ...

46

47

class MaxMetric(Metric):

48

def __init__(

49

self,

50

nan_strategy: str = "warn",

51

**kwargs

52

): ...

53

54

class MinMetric(Metric):

55

def __init__(

56

self,

57

nan_strategy: str = "warn",

58

**kwargs

59

): ...

60

61

class CatMetric(Metric):

62

def __init__(

63

self,

64

nan_strategy: str = "warn",

65

**kwargs

66

): ...

67

```

68

69

### Running Statistics

70

71

Metrics that compute statistics over a sliding window of the most recent `window` updates, rather than over all data seen so far.

72

73

```python { .api }

74

class RunningMean(Metric):

75

def __init__(

76

self,

77

window: int = 5,

78

**kwargs

79

): ...

80

81

class RunningSum(Metric):

82

def __init__(

83

self,

84

window: int = 5,

85

**kwargs

86

): ...

87

```

88

89

### Metric Wrappers

90

91

Advanced wrappers that enhance metric functionality with additional capabilities.

92

93

```python { .api }

94

class BootStrapper:

95

def __init__(

96

self,

97

base_metric: Metric,

98

num_bootstraps: int = 10,

99

mean: bool = True,

100

std: bool = True,

101

raw: bool = False,

102

quantile: Optional[Union[float, Tensor]] = None,

103

sampling_strategy: str = "poisson",

104

**kwargs

105

): ...

106

107

class ClasswiseWrapper:

108

def __init__(

109

self,

110

metric: Metric,

111

labels: Optional[List[str]] = None,

112

**kwargs

113

): ...

114

115

class MetricTracker:

116

def __init__(

117

self,

118

metric: Union[Metric, MetricCollection],

118

maximize: Union[bool, List[bool]] = True,

120

**kwargs

121

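): ...

# Must be called before each new tracked step (e.g. once per epoch) before update()/forward()
def increment(self) -> None: ...

# Returns the best value, or (best value, step index) when return_step=True
def best_metric(self, return_step: bool = False) -> Any: ...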

122

123

class MinMaxMetric:

124

def __init__(

125

self,

126

base_metric: Metric,

127

**kwargs

128

): ...

129

130

class MultioutputWrapper:

131

def __init__(

132

self,

133

base_metric: Metric,

134

num_outputs: int,

135

output_dim: int = -1,

remove_nans: bool = True,

squeeze_outputs: bool = True,

136

): ...

137

138

class MultitaskWrapper:

139

def __init__(

140

self,

141

task_metrics: Dict[str, Metric],

142

**kwargs

143

): ...

144

```

145

146

### Advanced Wrappers

147

148

Specialized wrappers for complex metric computation scenarios.

149

150

```python { .api }

151

class FeatureShare:

152

def __init__(

153

self,

154

metrics: Union[Metric, Sequence[Metric], Dict[str, Metric]],

155

max_cache_size: Optional[int] = None,

157

): ...

158

159

class LambdaInputTransformer:

160

def __init__(

161

self,

162

wrapped_metric: Metric,

162

transform_pred: Optional[Callable[[Tensor], Tensor]] = None,

163

transform_target: Optional[Callable[[Tensor], Tensor]] = None,

165

**kwargs

166

): ...

167

168

class MetricInputTransformer:

169

def __init__(

170

self,

171

wrapped_metric: Union[Metric, MetricCollection],

172

**kwargs

173

): ...

174

175

class Running:

176

def __init__(

177

self,

178

base_metric: Metric,

179

window: int = 5,

180

**kwargs

181

): ...

182

183

class BinaryTargetTransformer:

184

def __init__(

185

self,

186

wrapped_metric: Union[Metric, MetricCollection],

187

threshold: float = 0,

188

**kwargs

189

): ...

190

```

191

192

## Usage Examples

193

194

### Basic Metric Collection

195

196

```python

197

import torch

198

from torchmetrics import MetricCollection, Accuracy, F1Score, Precision, Recall

199

200

# Create a collection of classification metrics

201

metric_collection = MetricCollection({

202

'accuracy': Accuracy(task="multiclass", num_classes=3),

203

'f1': F1Score(task="multiclass", num_classes=3),

204

'precision': Precision(task="multiclass", num_classes=3),

205

'recall': Recall(task="multiclass", num_classes=3)

206

})

207

208

# Sample data

209

preds = torch.randn(32, 3).softmax(dim=-1)

210

target = torch.randint(0, 3, (32,))

211

212

# Compute all metrics at once

213

results = metric_collection(preds, target)

214

for metric_name, value in results.items():

215

print(f"{metric_name}: {value:.4f}")

216

```

217

218

### Aggregation Metrics

219

220

```python

221

import torch

from torchmetrics import MeanMetric, SumMetric, MaxMetric

222

223

# Initialize aggregation metrics

224

mean_loss = MeanMetric()

225

total_samples = SumMetric()

226

max_confidence = MaxMetric()

227

228

# Accumulate values across batches

229

for batch_idx in range(10):

230

batch_loss = torch.rand(1) * 2 # Random loss

231

batch_size = torch.tensor(32.0) # Batch size

232

batch_max_conf = torch.rand(1) # Max confidence in batch

233

234

mean_loss.update(batch_loss)

235

total_samples.update(batch_size)

236

max_confidence.update(batch_max_conf)

237

238

# Get final aggregated values

239

print(f"Mean Loss: {mean_loss.compute():.4f}")

240

print(f"Total Samples: {total_samples.compute():.0f}")

241

print(f"Max Confidence: {max_confidence.compute():.4f}")

242

```

243

244

### Bootstrapping for Confidence Intervals

245

246

```python

247

import torch

from torchmetrics.wrappers import BootStrapper

248

from torchmetrics import Accuracy

249

250

# Bootstrap accuracy for confidence intervals

251

base_accuracy = Accuracy(task="binary")

252

bootstrap_accuracy = BootStrapper(

253

base_accuracy,

254

num_bootstraps=1000,

255

mean=True,

256

std=True,

257

quantile=torch.tensor([0.025, 0.975]) # 95% confidence interval

258

)

259

260

# Sample binary classification data

261

preds = torch.rand(100)

262

target = torch.randint(0, 2, (100,))

263

264

# Compute bootstrapped statistics

265

bootstrap_results = bootstrap_accuracy(preds, target)

266

print(f"Mean Accuracy: {bootstrap_results['mean']:.4f}")

267

print(f"Std Accuracy: {bootstrap_results['std']:.4f}")

268

print(f"95% Confidence Interval: [{bootstrap_results['quantile'][0]:.4f}, {bootstrap_results['quantile'][1]:.4f}]")

269

```

270

271

### Per-Class Metrics

272

273

```python

274

import torch

from torchmetrics.wrappers import ClasswiseWrapper

275

from torchmetrics import F1Score

276

277

# Compute F1 score per class

278

class_labels = ['cat', 'dog', 'bird']

279

base_f1 = F1Score(task="multiclass", num_classes=3, average=None)

280

classwise_f1 = ClasswiseWrapper(base_f1, labels=class_labels)

281

282

# Sample multiclass data

283

preds = torch.randn(100, 3).softmax(dim=-1)

284

target = torch.randint(0, 3, (100,))

285

286

# Get per-class results

287

classwise_results = classwise_f1(preds, target)

288

for class_name, f1_score in classwise_results.items():

289

print(f"F1 Score for {class_name}: {f1_score:.4f}")

290

```

291

292

### Metric Tracking

293

294

```python

295

import torch

from torchmetrics.wrappers import MetricTracker

296

from torchmetrics import Accuracy

297

298

# Track best accuracy over time

299

tracker = MetricTracker(Accuracy(task="binary"), maximize=True)

300

301

# Simulate training epochs

302

num_epochs = 7

for epoch in range(num_epochs):
    preds = torch.rand(100)
    target = torch.randint(0, 2, (100,))

    # Each epoch needs its own tracked step: update()/forward() raise an error
    # until increment() has been called
    tracker.increment()
    tracker.update(preds, target)

# best_metric is a method; return_step=True also returns the index of the best step
best_acc, best_epoch = tracker.best_metric(return_step=True)
print(f"Best Accuracy: {best_acc:.4f}")
print(f"Best Accuracy at Epoch: {best_epoch}")

313

```

314

315

### Running Statistics

316

317

```python

318

import torch

from torchmetrics import RunningMean

319

320

# Running mean with sliding window

321

running_mean = RunningMean(window=50)

322

323

# Simulate streaming data

324

for i in range(200):

325

value = torch.tensor(float(i + torch.randn(1) * 0.1))

326

running_mean.update(value)

327

328

if i % 50 == 0:

329

print(f"Step {i}: Running Mean = {running_mean.compute():.2f}")

330

```

331

332

### Multi-output Wrapper

333

334

```python

335

import torch

from torchmetrics.wrappers import MultioutputWrapper

336

from torchmetrics import MeanSquaredError

337

338

# MSE for multi-output regression

339

multi_mse = MultioutputWrapper(MeanSquaredError(), num_outputs=3)

340

341

# Multi-output predictions and targets

342

preds = torch.randn(50, 3) # 50 samples, 3 outputs

343

target = torch.randn(50, 3)

344

345

# Compute MSE for each output

346

results = multi_mse(preds, target)

347

for i, mse in enumerate(results):

348

print(f"Output {i+1} MSE: {mse:.4f}")

349

```

350

351

### Multi-task Learning

352

353

```python

354

import torch

from torchmetrics.wrappers import MultitaskWrapper

355

from torchmetrics import Accuracy, MeanSquaredError

356

357

# Metrics for multi-task learning

358

task_metrics = {

359

'classification': Accuracy(task="multiclass", num_classes=5),

360

'regression': MeanSquaredError()

361

}

362

multitask_metric = MultitaskWrapper(task_metrics)

363

364

# Sample multi-task predictions

365

task_preds = {

366

'classification': torch.randn(32, 5).softmax(dim=-1),

367

'regression': torch.randn(32, 1)

368

}

369

task_targets = {

370

'classification': torch.randint(0, 5, (32,)),

371

'regression': torch.randn(32, 1)

372

}

373

374

# Compute metrics for all tasks

375

task_results = multitask_metric(task_preds, task_targets)

376

for task, result in task_results.items():

377

print(f"{task}: {result:.4f}")

378

```

379

380

### Input Transformation

381

382

```python

383

import torch

from torchmetrics.wrappers import LambdaInputTransformer

384

from torchmetrics import Accuracy

385

386

# Transform inputs before metric computation

387

def logits_to_probs(logits):

388

return torch.softmax(logits, dim=-1)

389

390

# Wrap accuracy with input transformation

391

transformed_accuracy = LambdaInputTransformer(

392

Accuracy(task="multiclass", num_classes=3),

393

transform_pred=logits_to_probs,

394

# transform_target is left at its default (None), so targets pass through unchanged

395

)

396

397

# Raw logits input

398

logits = torch.randn(32, 3)

399

target = torch.randint(0, 3, (32,))

400

401

# The wrapper applies logits_to_probs to the predictions before Accuracy sees them

402

acc = transformed_accuracy(logits, target)

403

print(f"Accuracy: {acc:.4f}")

404

```

405

406

## Types

407

408

```python { .api }

409

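from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

410

import torch

411

from torch import Tensor

from torchmetrics import Metric, MetricCollection

412

413

MetricDict = Dict[str, Metric]

414

ComputeGroupsType = Union[bool, List[List[str]]]

415

# A fixed set of allowed strings must be spelled Literal[...]; Union["warn", ...] would mean forward references to types
NaNStrategy = Union[Literal["error", "warn", "ignore"], float]

416

SamplingStrategy = Literal["poisson", "multinomial"]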

417

```