0
# Utilities and Aggregation
1
2
This section covers the core utilities for aggregating metrics, synchronizing them across distributed processes, and composing them into larger evaluation workflows: metric collections, aggregation metrics, running statistics, bootstrapping, and other metric wrappers.
3
4
## Capabilities
5
6
### Metric Collection
7
8
Container for organizing and computing multiple metrics simultaneously with automatic synchronization and state management.
9
10
```python { .api }
11
class MetricCollection:
12
def __init__(
13
self,
14
metrics: Union[Dict[str, Metric], List[Metric], Tuple[Metric, ...]],
15
prefix: Optional[str] = None,
16
postfix: Optional[str] = None,
17
compute_groups: Union[bool, List[List[str]]] = True,
18
**kwargs
19
): ...
20
21
def __call__(self, *args, **kwargs) -> Dict[str, Any]: ...
22
def update(self, *args, **kwargs) -> None: ...
23
def compute(self) -> Dict[str, Any]: ...
24
def reset(self) -> None: ...
25
def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": ...
26
```
27
28
### Aggregation Metrics
29
30
Basic metrics for accumulating and aggregating values across batches and distributed processes.
31
32
```python { .api }
33
class MeanMetric(Metric):
34
def __init__(
35
self,
36
nan_strategy: Union[str, float] = "warn",
37
**kwargs
38
): ...
39
40
class SumMetric(Metric):
41
def __init__(
42
self,
43
nan_strategy: Union[str, float] = "warn",
44
**kwargs
45
): ...
46
47
class MaxMetric(Metric):
48
def __init__(
49
self,
50
nan_strategy: Union[str, float] = "warn",
51
**kwargs
52
): ...
53
54
class MinMetric(Metric):
55
def __init__(
56
self,
57
nan_strategy: Union[str, float] = "warn",
58
**kwargs
59
): ...
60
61
class CatMetric(Metric):
62
def __init__(
63
self,
64
nan_strategy: Union[str, float] = "warn",
65
**kwargs
66
): ...
67
```
68
69
### Running Statistics
70
71
Metrics that compute statistics over a sliding window of the most recent `window` updates, rather than over every value seen so far.
72
73
```python { .api }
74
class RunningMean(Metric):
75
def __init__(
76
self,
77
window: int = 5,
78
**kwargs
79
): ...
80
81
class RunningSum(Metric):
82
def __init__(
83
self,
84
window: int = 5,
85
**kwargs
86
): ...
87
```
88
89
### Metric Wrappers
90
91
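Wrappers that add behavior on top of an existing metric: bootstrapped statistics, per-class labelling of results, tracking of the best value over time, min/max tracking, and multi-output or multi-task evaluation.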
92
93
```python { .api }
94
class BootStrapper:
95
def __init__(
96
self,
97
base_metric: Union[Metric, Callable],
98
num_bootstraps: int = 10,
99
mean: bool = True,
100
std: bool = True,
101
raw: bool = False,
102
quantile: Optional[Union[float, Tensor]] = None,
103
sampling_strategy: str = "poisson",
104
**kwargs
105
): ...
106
107
class ClasswiseWrapper:
108
def __init__(
109
self,
110
metric: Metric,
111
labels: Optional[List[str]] = None,
112
**kwargs
113
): ...
114
115
class MetricTracker:
116
def __init__(
117
self,
118
metric: Metric,
119
maximize: bool = True,
120
**kwargs
121
): ...
122
123
class MinMaxMetric:
124
def __init__(
125
self,
126
base_metric: Metric,
127
**kwargs
128
): ...
129
130
class MultioutputWrapper:
131
def __init__(
132
self,
133
base_metric: Metric,
134
num_outputs: int,
135
**kwargs
136
): ...
137
138
class MultitaskWrapper:
139
def __init__(
140
self,
141
task_metrics: Dict[str, Metric],
142
**kwargs
143
): ...
144
```
145
146
### Advanced Wrappers
147
148
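Specialized wrappers for sharing a feature extractor between metrics, transforming predictions or targets before they reach a metric, and evaluating a metric over a running window.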
149
150
```python { .api }
151
class FeatureShare:
152
def __init__(
153
self,
154
metrics: Union[Metric, List[Metric], Dict[str, Metric]],
155
max_cache_size: Optional[int] = None,
156
**kwargs
157
): ...
158
159
class LambdaInputTransformer:
160
def __init__(
161
self,
162
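wrapped_metric: Union[Metric, MetricCollection],
163
transform_pred: Optional[Callable[[Tensor], Tensor]] = None,
164
transform_target: Optional[Callable[[Tensor], Tensor]] = None,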
165
**kwargs
166
): ...
167
168
class MetricInputTransformer:
169
def __init__(
170
self,
171
wrapped_metric: Union[Metric, MetricCollection],
172
**kwargs
173
): ...
174
175
class Running:
176
def __init__(
177
self,
178
base_metric: Metric,
179
window: int = 5,
180
**kwargs
181
): ...
182
183
class BinaryTargetTransformer:
184
def __init__(
185
self,
186
wrapped_metric: Union[Metric, MetricCollection],
187
threshold: float = 0,
188
**kwargs
189
): ...
190
```
191
192
## Usage Examples
193
194
### Basic Metric Collection
195
196
```python
197
import torch
198
from torchmetrics import MetricCollection, Accuracy, F1Score, Precision, Recall
199
200
# Create a collection of classification metrics
201
metric_collection = MetricCollection({
202
'accuracy': Accuracy(task="multiclass", num_classes=3),
203
'f1': F1Score(task="multiclass", num_classes=3),
204
'precision': Precision(task="multiclass", num_classes=3),
205
'recall': Recall(task="multiclass", num_classes=3)
206
})
207
208
# Sample data
209
preds = torch.randn(32, 3).softmax(dim=-1)
210
target = torch.randint(0, 3, (32,))
211
212
# Compute all metrics at once
213
results = metric_collection(preds, target)
214
for metric_name, value in results.items():
215
print(f"{metric_name}: {value:.4f}")
216
```
217
218
### Aggregation Metrics
219
220
```python
221
from torchmetrics import MeanMetric, SumMetric, MaxMetric
222
223
# Initialize aggregation metrics
224
mean_loss = MeanMetric()
225
total_samples = SumMetric()
226
max_confidence = MaxMetric()
227
228
# Accumulate values across batches
229
for batch_idx in range(10):
230
batch_loss = torch.rand(1) * 2 # Random loss
231
batch_size = torch.tensor(32.0) # Batch size
232
batch_max_conf = torch.rand(1) # Max confidence in batch
233
234
mean_loss.update(batch_loss)
235
total_samples.update(batch_size)
236
max_confidence.update(batch_max_conf)
237
238
# Get final aggregated values
239
print(f"Mean Loss: {mean_loss.compute():.4f}")
240
print(f"Total Samples: {total_samples.compute():.0f}")
241
print(f"Max Confidence: {max_confidence.compute():.4f}")
242
```
243
244
### Bootstrapping for Confidence Intervals
245
246
```python
247
from torchmetrics.wrappers import BootStrapper
248
from torchmetrics import Accuracy
249
250
# Bootstrap accuracy for confidence intervals
251
base_accuracy = Accuracy(task="binary")
252
bootstrap_accuracy = BootStrapper(
253
base_accuracy,
254
num_bootstraps=1000,
255
mean=True,
256
std=True,
257
quantile=torch.tensor([0.025, 0.975]) # 95% confidence interval
258
)
259
260
# Sample binary classification data
261
preds = torch.rand(100)
262
target = torch.randint(0, 2, (100,))
263
264
# Compute bootstrapped statistics
265
bootstrap_results = bootstrap_accuracy(preds, target)
266
print(f"Mean Accuracy: {bootstrap_results['mean']:.4f}")
267
print(f"Std Accuracy: {bootstrap_results['std']:.4f}")
268
print(f"95% Confidence Interval: [{bootstrap_results['quantile'][0]:.4f}, {bootstrap_results['quantile'][1]:.4f}]")
269
```
270
271
### Per-Class Metrics
272
273
```python
274
from torchmetrics.wrappers import ClasswiseWrapper
275
from torchmetrics import F1Score
276
277
# Compute F1 score per class
278
class_labels = ['cat', 'dog', 'bird']
279
base_f1 = F1Score(task="multiclass", num_classes=3, average=None)
280
classwise_f1 = ClasswiseWrapper(base_f1, labels=class_labels)
281
282
# Sample multiclass data
283
preds = torch.randn(100, 3).softmax(dim=-1)
284
target = torch.randint(0, 3, (100,))
285
286
# Get per-class results
287
classwise_results = classwise_f1(preds, target)
288
for class_name, f1_score in classwise_results.items():
289
print(f"F1 Score for {class_name}: {f1_score:.4f}")
290
```
291
292
### Metric Tracking
293
294
```python
295
from torchmetrics.wrappers import MetricTracker
296
from torchmetrics import Accuracy
297
298
# Track best accuracy over time
299
tracker = MetricTracker(Accuracy(task="binary"), maximize=True)
300
301
# Simulate training epochs
302
accuracies = [0.6, 0.7, 0.65, 0.8, 0.75, 0.85, 0.82]
303
304
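for epoch, acc in enumerate(accuracies):
305
    preds = torch.rand(100)
306
    target = torch.randint(0, 2, (100,))
307
308
    # Start a new tracked step for this epoch, then update it
309
    tracker.increment()
    tracker.update(preds, target)
310
311
# best_metric is a method; return_step=True also returns the index of the best step
best_acc, best_epoch = tracker.best_metric(return_step=True)
print(f"Best Accuracy: {best_acc:.4f}")
312
print(f"Best Accuracy at Epoch: {best_epoch}")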
313
```
314
315
### Running Statistics
316
317
```python
318
from torchmetrics import RunningMean
319
320
# Running mean with sliding window
321
running_mean = RunningMean(window=50)
322
323
# Simulate streaming data
324
for i in range(200):
325
value = torch.tensor(float(i + torch.randn(1) * 0.1))
326
running_mean.update(value)
327
328
if i % 50 == 0:
329
print(f"Step {i}: Running Mean = {running_mean.compute():.2f}")
330
```
331
332
### Multi-output Wrapper
333
334
```python
335
from torchmetrics.wrappers import MultioutputWrapper
336
from torchmetrics import MeanSquaredError
337
338
# MSE for multi-output regression
339
multi_mse = MultioutputWrapper(MeanSquaredError(), num_outputs=3)
340
341
# Multi-output predictions and targets
342
preds = torch.randn(50, 3) # 50 samples, 3 outputs
343
target = torch.randn(50, 3)
344
345
# Compute MSE for each output
346
results = multi_mse(preds, target)
347
for i, mse in enumerate(results):
348
print(f"Output {i+1} MSE: {mse:.4f}")
349
```
350
351
### Multi-task Learning
352
353
```python
354
from torchmetrics.wrappers import MultitaskWrapper
355
from torchmetrics import Accuracy, MeanSquaredError
356
357
# Metrics for multi-task learning
358
task_metrics = {
359
'classification': Accuracy(task="multiclass", num_classes=5),
360
'regression': MeanSquaredError()
361
}
362
multitask_metric = MultitaskWrapper(task_metrics)
363
364
# Sample multi-task predictions
365
task_preds = {
366
'classification': torch.randn(32, 5).softmax(dim=-1),
367
'regression': torch.randn(32, 1)
368
}
369
task_targets = {
370
'classification': torch.randint(0, 5, (32,)),
371
'regression': torch.randn(32, 1)
372
}
373
374
# Compute metrics for all tasks
375
task_results = multitask_metric(task_preds, task_targets)
376
for task, result in task_results.items():
377
print(f"{task}: {result:.4f}")
378
```
379
380
### Input Transformation
381
382
```python
383
from torchmetrics.wrappers import LambdaInputTransformer
384
from torchmetrics import Accuracy
385
386
# Transform inputs before metric computation
387
def logits_to_probs(logits):
388
return torch.softmax(logits, dim=-1)
389
390
# Wrap accuracy with input transformation
391
transformed_accuracy = LambdaInputTransformer(
392
Accuracy(task="multiclass", num_classes=3),
393
transform_pred=logits_to_probs,
394
transform_target=None  # Leave targets untransformed
395
)
396
397
# Raw logits input
398
logits = torch.randn(32, 3)
399
target = torch.randint(0, 3, (32,))
400
401
# Accuracy automatically applies softmax
402
acc = transformed_accuracy(logits, target)
403
print(f"Accuracy: {acc:.4f}")
404
```
405
406
## Types
407
408
```python { .api }
409
from typing import Dict, List, Literal, Optional, Union, Callable, Any, Tuple
410
import torch
411
from torch import Tensor
412
413
MetricDict = Dict[str, Metric]
414
ComputeGroupsType = Union[bool, List[List[str]]]
415
NaNStrategy = Literal["warn", "error", "ignore"]
416
SamplingStrategy = Literal["poisson", "multinomial"]
417
```