0
# Utilities and Helpers
1
2
General utilities for distributed training, model management, checkpointing, logging, and other supporting functionality for production computer vision workflows.
3
4
## Capabilities
5
6
### Model Utilities
7
8
Functions for unwrapping wrapped models, extracting state dicts, freezing and unfreezing parameters, and reparameterizing models for inference.
9
10
```python { .api }
11
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return the bare module hidden inside DDP, EMA, or similar wrappers.

    Args:
        model: Model instance that may be wrapped.

    Returns:
        The unwrapped base model.
    """
21
22
def get_state_dict(model: torch.nn.Module, unwrap_fn: Callable = unwrap_model) -> Dict[str, Any]:
    """Fetch the state dictionary of ``model`` once its wrappers are removed.

    Args:
        model: Model (possibly wrapped) whose state dict is wanted.
        unwrap_fn: Callable that strips wrappers from ``model`` before the
            state dict is taken; defaults to ``unwrap_model``.

    Returns:
        The model's state dictionary.
    """
36
37
def freeze(model: torch.nn.Module) -> None:
    """Turn off gradient tracking for every parameter of ``model``.

    Args:
        model: Model whose parameters should be frozen.
    """
44
45
def unfreeze(model: torch.nn.Module) -> None:
    """Turn gradient tracking back on for every parameter of ``model``.

    Args:
        model: Model whose parameters should be unfrozen.
    """
52
53
def reparameterize_model(model: torch.nn.Module, **kwargs) -> torch.nn.Module:
    """Rewrite ``model`` into an inference-optimized, reparameterized form.

    Args:
        model: Model to reparameterize.
        **kwargs: Options controlling the reparameterization.

    Returns:
        The reparameterized model.
    """
67
```
68
69
### Distributed Training Utilities
70
71
Functions for initializing and managing distributed training across multiple devices and nodes.
72
73
```python { .api }
74
def init_distributed_device(args) -> Tuple[torch.device, int]:
    """Set up the device for (possibly distributed) training.

    Args:
        args: Arguments namespace carrying the distributed-training
            configuration.

    Returns:
        Tuple of ``(device, world_size)``: the device this process should
        use and the total number of participating processes.
    """
84
85
def distribute_bn(model: torch.nn.Module, world_size: int, reduce: bool = False) -> None:
    """Synchronize batch-norm statistics between distributed processes.

    Args:
        model: Model containing the batch-norm layers to synchronize.
        world_size: Total number of distributed processes.
        reduce: When True, reduce the statistics across all processes.
    """
98
99
def reduce_tensor(tensor: torch.Tensor, world_size: int = 1) -> torch.Tensor:
    """Combine ``tensor`` across all distributed processes.

    Args:
        tensor: Tensor to reduce.
        world_size: Number of participating processes.

    Returns:
        The reduced tensor.
    """
113
114
def world_info_from_env() -> Tuple[int, int, int]:
    """Read the distributed world layout from environment variables.

    Returns:
        Tuple ``(local_rank, world_rank, world_size)``.
    """
121
122
def is_distributed_env() -> bool:
    """Report whether the process appears to run in a distributed environment.

    Returns:
        True when a distributed environment is detected, else False.
    """
129
```
130
131
### Mixed Precision Training
132
133
Gradient scalers for automatic mixed precision (AMP) training, backed by either NVIDIA Apex or native PyTorch.
134
135
```python { .api }
136
class ApexScaler:
    """Gradient (loss) scaler backed by NVIDIA Apex.

    Args:
        loss_scale: Loss-scaling mode, given as a string (default
            ``'dynamic'``). This is a mode selector, not a numeric factor;
            the starting numeric scale is ``init_scale``.
        init_scale: Initial numeric value of the loss scale.
        scale_factor: Factor by which the scale is adjusted.
        scale_window: Scale adjustment window.
    """

    def __init__(
        self,
        loss_scale: str = 'dynamic',
        init_scale: float = 2.**16,
        scale_factor: float = 2.0,
        scale_window: int = 2000
    ): ...

    def scale_loss(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer):
        """Scale ``loss`` before backpropagation for ``optimizer``."""

    def unscale_grads(self, optimizer: torch.optim.Optimizer):
        """Undo the loss scaling on the gradients held by ``optimizer``."""

    def update_scale(self, overflow: bool):
        """Adjust the loss scale; ``overflow`` flags a gradient overflow."""
158
159
class NativeScaler:
    """Mixed-precision gradient scaler built on native PyTorch AMP.

    Args:
        enabled: Whether gradient scaling is active.
        init_scale: Starting scale factor.
        growth_factor: Multiplier applied when the scale grows.
        backoff_factor: Multiplier applied when the scale backs off.
        growth_interval: Interval between scale growth attempts.
    """

    def __init__(
        self,
        enabled: bool = True,
        init_scale: float = 2.0 ** 16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ): ...

    def scale(self, loss: torch.Tensor) -> torch.Tensor:
        """Return ``loss`` multiplied by the current scale factor."""

    def step(self, optimizer: torch.optim.Optimizer) -> None:
        """Run an optimizer step using the scaled gradients."""

    def update(self) -> None:
        """Update the scale factor for the next iteration."""
183
```
184
185
### CUDA and Performance Utilities
186
187
Functions for configuring the PyTorch JIT (legacy mode and fuser selection) and for seeding random number generators for reproducibility.
188
189
```python { .api }
190
def set_jit_legacy(enable: bool) -> None:
    """Switch the legacy JIT mode on or off.

    Args:
        enable: True to enable legacy JIT mode, False to disable it.
    """
197
198
def set_jit_fuser(fuser_name: str) -> None:
    """Choose which JIT fuser is used.

    Args:
        fuser_name: Fuser identifier, one of ``'te'``, ``'old'`` or ``'nvfuser'``.
    """
205
206
def random_seed(seed: int, rank: int = 0) -> None:
    """Seed the random number generators for reproducible runs.

    Args:
        seed: Base seed value.
        rank: Rank of the calling process in distributed training.
    """
214
```
215
216
### Logging and Configuration
217
218
Utilities for logging setup, command-line argument parsing, and natural (human-friendly) sorting of strings that contain numbers.
219
220
```python { .api }
221
def setup_default_logging(default_level: int = logging.INFO, log_path: str = '', **kwargs) -> None:
    """Install the default logging configuration.

    Args:
        default_level: Logging level applied by default.
        log_path: Path of a log file to write to (empty string for none).
        **kwargs: Extra logging configuration options.
    """
234
235
def natural_key(string_: str) -> List[Union[int, str]]:
    """Build a sort key that orders embedded numbers numerically.

    Args:
        string_: String to build the key for.

    Returns:
        List of string and integer components usable as a sort key.
    """
245
246
def add_bool_arg(parser, name: str, default: bool = False, help: str = '') -> None:
    """Register a paired ``--name`` / ``--no-name`` boolean flag on ``parser``.

    Args:
        parser: ``ArgumentParser`` to add the flag to.
        name: Flag name (without leading dashes).
        default: Value used when neither flag is given.
        help: Help text for the flag.
    """
261
```
262
263
### Training Summary and Output
264
265
Functions for managing training outputs, experiment directories, and result summaries.
266
267
```python { .api }
268
def update_summary(
    epoch: int,
    train_metrics: Dict[str, float],
    eval_metrics: Dict[str, float],
    filename: str,
    lr: Optional[float] = None,
    write_header: bool = False,
    log_wandb: bool = False,
) -> None:
    """Append one epoch's metrics to the training summary.

    Args:
        epoch: Current epoch.
        train_metrics: Training metrics, keyed by metric name.
        eval_metrics: Evaluation metrics, keyed by metric name.
        filename: Path of the summary file.
        lr: Current learning rate, or None if it should not be recorded.
        write_header: Whether to write the CSV header.
        log_wandb: Whether to also log to Weights & Biases.
    """
289
290
def get_outdir(path: str, *paths: str, inc: bool = False) -> str:
    """Resolve the output directory used for an experiment.

    Args:
        path: Base output path.
        *paths: Further path components appended to ``path``.
        inc: If True, auto-increment the directory name.

    Returns:
        Path of the output directory.
    """
302
```
303
304
## Training Monitoring Classes
305
306
### Metrics Tracking
307
308
```python { .api }
309
class AverageMeter:
    """Track the latest value and running average of a metric.

    Args:
        name: Display name of the metric.
        fmt: Format spec used when rendering values (e.g. ``':.4e'``).
    """

    def __init__(self, name: str = '', fmt: str = ':f'): ...

    def reset(self) -> None:
        """Clear all accumulated statistics."""

    def update(self, val: float, n: int = 1) -> None:
        """Fold a new observation into the meter.

        Args:
            val: Newly observed value.
            n: Number of samples that ``val`` stands for.
        """

    def __str__(self) -> str:
        """Render the meter's current state as text."""
334
335
def accuracy(output: torch.Tensor, target: torch.Tensor, topk: Tuple[int, ...] = (1,)) -> List[torch.Tensor]:
    """Compute top-k accuracy for each requested k.

    Args:
        output: Model predictions of shape ``[batch_size, num_classes]``.
        target: Ground-truth labels of shape ``[batch_size]``.
        topk: The k values to evaluate.

    Returns:
        One accuracy tensor per entry of ``topk``.
    """
351
```
352
353
### Model EMA Management
354
355
```python { .api }
356
class ModelEma:
    """Exponential moving average (EMA) of model weights, kept as shadow weights.

    Args:
        model: Model whose weights are tracked.
        decay: EMA decay rate (default: 0.9999).
        device: Device on which to store the EMA parameters, or None.
        resume: Path of a checkpoint to resume the EMA from.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        device: Optional[torch.device] = None,  # explicit Optional: default is None
        resume: str = ''
    ): ...

    def update(self, model: torch.nn.Module) -> None:
        """
        Update EMA parameters from model.

        Args:
            model: Source model for updates
        """

    def set(self, model: torch.nn.Module) -> None:
        """
        Set EMA parameters from model (copy all parameters).

        Args:
            model: Source model to copy from
        """
390
391
class ModelEmaV2:
    """Model EMA v2, which adjusts its decay based on training progress.

    Args:
        model: Model to track.
        decay: Base decay rate.
        decay_type: Decay adjustment type, ``'exponential'`` or ``'linear'``.
        device: Device for the EMA parameters, or None.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        decay_type: str = 'exponential',
        device: Optional[torch.device] = None  # explicit Optional: default is None
    ): ...
409
410
class ModelEmaV3:
    """Model EMA v3, with performance and memory-efficiency optimizations.

    Args:
        model: Model to track.
        decay: EMA decay rate.
        update_after_step: Number of steps before EMA updates start.
        use_ema_warmup: Whether to warm up the EMA decay.
        inv_gamma: Inverse gamma used by the warmup.
        power: Power used by the warmup.
        min_value: Minimum decay value.
        device: Device for the EMA parameters, or None.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        update_after_step: int = 100,
        use_ema_warmup: bool = False,
        inv_gamma: float = 1.0,
        power: float = 2/3,
        min_value: float = 0.0,
        device: Optional[torch.device] = None  # explicit Optional: default is None
    ): ...
436
```
437
438
### Checkpoint Management
439
440
```python { .api }
441
class CheckpointSaver:
    """
    Saves model checkpoints with configurable retention and recovery policies.

    Args:
        model: Model to save
        optimizer: Optimizer state to save
        args: Training arguments/configuration
        model_ema: EMA model to save, or None
        amp_scaler: Mixed precision scaler, or None
        checkpoint_prefix: Prefix for checkpoint filenames
        recovery_prefix: Prefix for recovery checkpoints
        checkpoint_dir: Directory for regular checkpoints
        recovery_dir: Directory for recovery checkpoints
        decreasing: If True, a lower metric is better (e.g. a loss);
            if False, a higher metric is better (e.g. accuracy)
        max_history: Maximum number of checkpoints to keep
        unwrap_fn: Function to unwrap model before saving
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        args = None,
        # Forward reference plus explicit Optional: the default is None.
        model_ema: Optional["ModelEma"] = None,
        amp_scaler = None,
        checkpoint_prefix: str = 'checkpoint',
        recovery_prefix: str = 'recovery',
        checkpoint_dir: str = '',
        recovery_dir: str = '',
        decreasing: bool = False,
        max_history: int = 10,
        unwrap_fn: Callable = unwrap_model
    ): ...

    def save_checkpoint(
        self,
        epoch: int,
        metric: Optional[float] = None
    ) -> Tuple[str, bool]:
        """
        Save checkpoint if metric improved.

        Args:
            epoch: Current epoch number
            metric: Metric value for comparison, or None

        Returns:
            Tuple of (checkpoint_path, is_best)
        """

    def save_recovery(
        self,
        epoch: int,
        batch_idx: int = 0
    ) -> str:
        """
        Save recovery checkpoint for resuming interrupted training.

        Args:
            epoch: Current epoch
            batch_idx: Current batch index

        Returns:
            Path to saved recovery checkpoint
        """
507
```
508
509
## Usage Examples
510
511
### Basic Training Setup with Utilities
512
513
```python
514
import logging

import torch  # needed for torch.optim below
import timm
from timm.utils import (
    setup_default_logging, random_seed, ModelEma,
    CheckpointSaver, AverageMeter, accuracy
)

# Setup logging
setup_default_logging(log_path='training.log')
logger = logging.getLogger(__name__)

# Set random seed for reproducibility
random_seed(42, rank=0)

# Create model and training components
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Setup EMA tracking
model_ema = ModelEma(model, decay=0.9999)

# Setup checkpoint saving
saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    model_ema=model_ema,
    checkpoint_dir='./checkpoints',
    max_history=5,
    decreasing=False  # Higher accuracy is better
)

# Setup metrics tracking
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
549
```
550
551
### Distributed Training Setup
552
553
```python
554
import torch

from timm.utils import (
    AverageMeter, init_distributed_device, distribute_bn, reduce_tensor,
    is_distributed_env
)

# Initialize distributed training (`args` is the parsed command-line
# namespace and `model` the model built earlier in the script)
device, world_size = init_distributed_device(args)
model = model.to(device)

if is_distributed_env():
    # Wrap model for distributed training
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[device], find_unused_parameters=False
    )

criterion = torch.nn.CrossEntropyLoss()


# In training loop - reduce metrics across processes
def train_epoch(model, loader, optimizer, device, world_size):
    losses = AverageMeter('Loss')

    for images, target in loader:
        images, target = images.to(device), target.to(device)

        output = model(images)
        loss = criterion(output, target)

        # Backward and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The reduced loss is only logged, so reduce a detached copy rather
        # than a tensor still attached to the autograd graph.
        reduced_loss = loss.detach()
        if world_size > 1:
            reduced_loss = reduce_tensor(reduced_loss, world_size)

        losses.update(reduced_loss.item(), images.size(0))

    # Synchronize BatchNorm running statistics after the epoch's training,
    # so every process validates and checkpoints the same statistics.
    # (Syncing before training would only share the initial values.)
    if world_size > 1:
        distribute_bn(model, world_size, reduce=True)

    return losses.avg
594
```
595
596
### Mixed Precision Training
597
598
```python
599
import torch

from timm.utils import NativeScaler

# Setup mixed precision training (`model` comes from earlier in the script)
device = torch.device('cuda')
criterion = torch.nn.CrossEntropyLoss()
scaler = NativeScaler()
model = model.to(device)


def train_step(model, input, target, optimizer, scaler):
    optimizer.zero_grad()

    # Forward pass with autocast; torch.autocast(device_type=...) replaces
    # the deprecated torch.cuda.amp.autocast()
    with torch.autocast(device_type='cuda'):
        output = model(input)
        loss = criterion(output, target)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item()
619
```
620
621
### Complete Training Loop with Utilities
622
623
```python
624
import logging

import torch
import timm
from timm.utils import (
    setup_default_logging, random_seed, ModelEmaV2, NativeScaler,
    CheckpointSaver, AverageMeter, accuracy
)

logger = logging.getLogger(__name__)


def train_model():
    # Assumes the surrounding script defines `num_epochs`, `train_loader`,
    # `val_loader` and a `validate(model, loader)` helper returning top-1
    # accuracy.
    setup_default_logging()
    random_seed(42)

    device = torch.device('cuda')
    criterion = torch.nn.CrossEntropyLoss()

    # Model setup: move the model to the device the batches are sent to,
    # before the optimizer and the EMA shadow weights are created from it.
    model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

    # Training utilities
    model_ema = ModelEmaV2(model, decay=0.9999)
    scaler = NativeScaler()
    saver = CheckpointSaver(
        model, optimizer, model_ema=model_ema, amp_scaler=scaler,
        checkpoint_dir='./checkpoints'
    )

    # Metrics
    train_losses = AverageMeter('Train Loss')
    train_acc1 = AverageMeter('Train Acc@1')

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_losses.reset()
        train_acc1.reset()

        for images, target in train_loader:
            images, target = images.to(device), target.to(device)

            # Mixed precision forward pass (torch.autocast replaces the
            # deprecated torch.cuda.amp.autocast)
            with torch.autocast(device_type='cuda'):
                output = model(images)
                loss = criterion(output, target)

            # Backward pass
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Update EMA
            model_ema.update(model)

            # Metrics: only top-1 is tracked, so only top-1 is computed
            (acc1,) = accuracy(output, target, topk=(1,))
            train_losses.update(loss.item(), images.size(0))
            train_acc1.update(acc1.item(), images.size(0))

        # Validation and checkpointing
        val_acc = validate(model_ema.module, val_loader)
        saver.save_checkpoint(epoch, val_acc)

        logger.info(f'Epoch {epoch}: Train Loss {train_losses.avg:.4f}, '
                    f'Train Acc {train_acc1.avg:.2f}%, Val Acc {val_acc:.2f}%')
678
```
679
680
## Types
681
682
```python { .api }
683
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

import torch

# --- Devices and the distributed world ---------------------------------
DeviceType = torch.device
# (local_rank, world_rank, world_size), as returned by world_info_from_env()
WorldInfo = Tuple[int, int, int]

# --- Metrics -----------------------------------------------------------
MetricValue = Union[float, torch.Tensor]
MetricDict = Dict[str, MetricValue]

# --- Checkpointing -----------------------------------------------------
CheckpointDict = Dict[str, Any]
UnwrapFunction = Callable[[torch.nn.Module], torch.nn.Module]

# --- Loss scalers ------------------------------------------------------
# NOTE: because the Union includes Any, type checkers treat it as Any.
LossScaler = Union[torch.cuda.amp.GradScaler, Any]

# --- Logging -----------------------------------------------------------
LogLevel = int
Logger = logging.Logger

# --- Utility callables -------------------------------------------------
SeedFunction = Callable[[int, int], None]  # shape of random_seed(seed, rank)
ReduceFunction = Callable[[torch.Tensor, int], torch.Tensor]  # reduce_tensor(tensor, world_size)
709
```