0
# Training Infrastructure
1
2
Comprehensive training utilities including optimizers, learning rate schedulers, loss functions, and training helpers for building complete training pipelines.
3
4
## Capabilities
5
6
### Optimizer Creation
7
8
Factory functions for creating optimizers with advanced configurations and parameter grouping strategies.
9
10
```python { .api }
11
def create_optimizer_v2(
12
model_or_params: Union[torch.nn.Module, ParamsT],
13
opt: str = 'sgd',
14
lr: Optional[float] = None,
15
weight_decay: float = 0.0,
16
momentum: float = 0.9,
17
foreach: Optional[bool] = None,
18
filter_bias_and_bn: bool = True,
19
layer_decay: Optional[float] = None,
20
layer_decay_min_scale: float = 0.0,
21
layer_decay_no_opt_scale: Optional[float] = None,
22
param_group_fn: Optional[Callable[[torch.nn.Module], ParamsT]] = None,
23
**kwargs: Any
24
) -> torch.optim.Optimizer:
25
"""
26
Create optimizer with v2 interface.
27
28
Args:
29
model_or_params: Model instance or parameter groups
30
opt: Optimizer name ('sgd', 'adam', 'adamw', 'rmsprop', etc.)
31
lr: Learning rate
32
weight_decay: Weight decay coefficient
33
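momentum: Momentum coefficient (for optimizers that support it, e.g. SGD)
foreach: Enable/disable the foreach (multi-tensor) implementation; None uses the optimizer default
filter_bias_and_bn: If True, exclude bias and normalization parameters from weight decay
layer_decay: Optional layer-wise learning rate decay factor
layer_decay_min_scale: Minimum clamp value for the layer-wise scale factor
layer_decay_no_opt_scale: Layer scale below which parameters are not optimized
param_group_fn: Optional function that builds custom parameter groups from the model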
34
eps: Epsilon for numerical stability (not a named parameter; forwarded via **kwargs)
35
betas: Beta coefficients for Adam-family optimizers (not a named parameter; forwarded via **kwargs)
36
opt_args: Additional optimizer arguments (not a named parameter here; supply them as individual keyword arguments via **kwargs)
37
**kwargs: Additional optimizer-specific keyword arguments (e.g. eps, betas) forwarded to the optimizer
38
39
Returns:
40
Configured optimizer instance
41
"""
42
43
def create_optimizer(
44
args,
45
model: torch.nn.Module,
46
filter_bias_and_bn: bool = True
47
):
48
"""
49
Create optimizer from arguments (legacy interface).
50
51
Args:
52
args: Arguments namespace with optimizer configuration
53
model: Model to optimize
54
filter_bias_and_bn: Filter bias and batch norm parameters
55
56
Returns:
57
Configured optimizer
58
"""
59
60
def list_optimizers() -> List[str]:
61
"""
62
List available optimizer names.
63
64
Returns:
65
List of supported optimizer names
66
"""
67
68
def get_optimizer_class(optimizer_name: str):
69
"""
70
Get optimizer class by name.
71
72
Args:
73
optimizer_name: Name of optimizer
74
75
Returns:
76
Optimizer class
77
"""
78
```
79
80
### Parameter Grouping
81
82
Functions for creating parameter groups with different learning rates, weight decay, and layer-specific configurations.
83
84
```python { .api }
85
def param_groups_layer_decay(
86
model: torch.nn.Module,
87
weight_decay: float = 0.05,
88
no_weight_decay_list: List[str] = None,
89
layer_decay: float = 0.75,
90
end_lr_scale: float = 1.0
91
) -> List[dict]:
92
"""
93
Create parameter groups with layer-wise learning rate decay.
94
95
Args:
96
model: Model to create parameter groups for
97
weight_decay: Base weight decay rate
98
no_weight_decay_list: Parameters to exclude from weight decay
99
layer_decay: Layer decay factor
100
end_lr_scale: Learning rate scale for final layer
101
102
Returns:
103
List of parameter group dictionaries
104
"""
105
106
def param_groups_weight_decay(
107
model: torch.nn.Module,
108
weight_decay: float = 1e-5,
109
no_weight_decay_list: List[str] = None
110
) -> List[dict]:
111
"""
112
Create parameter groups with selective weight decay.
113
114
Args:
115
model: Model to create parameter groups for
116
weight_decay: Weight decay rate
117
no_weight_decay_list: Parameters to exclude from weight decay
118
119
Returns:
120
List of parameter group dictionaries
121
"""
122
```
123
124
## Optimizer Classes
125
126
### Custom Optimizers
127
128
```python { .api }
129
class AdaBelief(torch.optim.Optimizer):
130
"""
131
AdaBelief optimizer.
132
133
Args:
134
params: Iterable of parameters
135
lr: Learning rate
136
betas: Beta coefficients
137
eps: Epsilon for numerical stability
138
weight_decay: Weight decay coefficient
139
amsgrad: Use AMSGrad variant
140
weight_decouple: Decouple weight decay
141
fixed_decay: Use fixed decay
142
rectify: Use rectification
143
"""
144
145
def __init__(
146
self,
147
params,
148
lr: float = 1e-3,
149
betas: tuple = (0.9, 0.999),
150
eps: float = 1e-16,
151
weight_decay: float = 0,
152
amsgrad: bool = False,
153
weight_decouple: bool = True,
154
fixed_decay: bool = False,
155
rectify: bool = True
156
): ...
157
158
class Lamb(torch.optim.Optimizer):
159
"""
160
LAMB (Layer-wise Adaptive Moments) optimizer.
161
162
Args:
163
params: Iterable of parameters
164
lr: Learning rate
165
betas: Beta coefficients
166
eps: Epsilon for numerical stability
167
weight_decay: Weight decay coefficient
168
grad_averaging: Use gradient averaging
169
max_grad_norm: Maximum gradient norm
170
trust_clip: Trust region clipping
171
always_adapt: Always adapt learning rate
172
"""
173
174
def __init__(
175
self,
176
params,
177
lr: float = 1e-3,
178
betas: tuple = (0.9, 0.999),
179
eps: float = 1e-6,
180
weight_decay: float = 0.01,
181
grad_averaging: bool = True,
182
max_grad_norm: float = 1.0,
183
trust_clip: bool = False,
184
always_adapt: bool = False
185
): ...
186
187
class Lion(torch.optim.Optimizer):
188
"""
189
Lion (EvoLved Sign Momentum) optimizer.
190
191
Args:
192
params: Iterable of parameters
193
lr: Learning rate
194
betas: Beta coefficients for momentum
195
weight_decay: Weight decay coefficient
196
use_triton: Use Triton kernel implementation
197
"""
198
199
def __init__(
200
self,
201
params,
202
lr: float = 1e-4,
203
betas: tuple = (0.9, 0.99),
204
weight_decay: float = 0.0,
205
use_triton: bool = False
206
): ...
207
208
class Lookahead(torch.optim.Optimizer):
209
"""
210
Lookahead optimizer wrapper.
211
212
Args:
213
base_optimizer: Base optimizer to wrap
214
alpha: Lookahead step size
215
k: Lookahead frequency
216
pullback_momentum: Pullback momentum mode
217
"""
218
219
def __init__(
220
self,
221
base_optimizer: torch.optim.Optimizer,
222
alpha: float = 0.5,
223
k: int = 6,
224
pullback_momentum: str = "none"
225
): ...
226
```
227
228
## Learning Rate Schedulers
229
230
### Scheduler Creation
231
232
```python { .api }
233
def create_scheduler_v2(
234
optimizer: torch.optim.Optimizer,
235
sched: str = 'cosine',
236
num_epochs: int = 300,
237
decay_epochs: int = 90,
238
decay_milestones: List[int] = (90, 180, 270),
239
cooldown_epochs: int = 0,
240
patience_epochs: int = 10,
241
decay_rate: float = 0.1,
242
min_lr: float = 0,
243
warmup_lr: float = 1e-5,
244
warmup_epochs: int = 0,
245
warmup_prefix: bool = False,
246
noise: Union[float, List[float]] = None,
247
noise_pct: float = 0.67,
248
noise_std: float = 1.0,
249
noise_seed: int = 42,
250
cycle_mul: float = 1.0,
251
cycle_decay: float = 0.1,
252
cycle_limit: int = 1,
253
k_decay: float = 1.0,
254
plateau_mode: str = 'max',
255
step_on_epochs: bool = True,
256
updates_per_epoch: int = 0
257
):
258
"""
259
Create learning rate scheduler with v2 interface.
260
261
Args:
262
optimizer: Optimizer instance
263
sched: Scheduler type ('step', 'cosine', 'tanh', 'poly', 'plateau', etc.)
264
num_epochs: Total number of training epochs
265
decay_epochs: Epochs between learning rate decay
decay_milestones: Epoch milestones at which the learning rate is decayed (multistep schedule)
266
decay_rate: Learning rate decay factor
267
min_lr: Minimum learning rate
268
warmup_lr: Warmup initial learning rate
269
warmup_epochs: Number of warmup epochs
warmup_prefix: Treat warmup as a prefix that runs before the main schedule starts
270
cooldown_epochs: Number of cooldown epochs
271
patience_epochs: Patience for plateau scheduler
272
cycle_mul: Cycle length multiplier
273
cycle_decay: Cycle decay factor
274
cycle_limit: Maximum number of cycles
275
noise: Learning rate noise range (a single value or a list of values)
276
noise_pct: Noise percentage
277
noise_std: Noise standard deviation
278
noise_seed: Random seed for noise
279
k_decay: K decay factor
280
plateau_mode: Plateau mode ('min' or 'max')
281
step_on_epochs: Step on epochs vs iterations
282
updates_per_epoch: Updates per epoch for iteration-based stepping
283
**kwargs: Additional scheduler arguments
284
285
Returns:
286
Configured scheduler instance
287
"""
288
289
def scheduler_kwargs(args) -> dict:
290
"""
291
Extract scheduler keyword arguments from args.
292
293
Args:
294
args: Arguments namespace
295
296
Returns:
297
Dictionary of scheduler arguments
298
"""
299
```
300
301
### Scheduler Classes
302
303
```python { .api }
304
class CosineLRScheduler:
305
"""
306
Cosine annealing learning rate scheduler with warm restarts.
307
308
Args:
309
optimizer: Optimizer instance
310
t_initial: Initial number of epochs/iterations
311
lr_min: Minimum learning rate
312
cycle_mul: Cycle length multiplier
313
cycle_decay: Cycle amplitude decay
314
cycle_limit: Maximum number of cycles
315
warmup_t: Warmup iterations
316
warmup_lr_init: Initial warmup learning rate
317
warmup_prefix: Warmup before first cycle
318
t_in_epochs: Interpret t_initial as epochs
319
noise_range_t: Noise range for time
320
noise_pct: Noise percentage
321
noise_std: Noise standard deviation
322
noise_seed: Random seed
323
initialize: Initialize learning rates
324
"""
325
326
def __init__(
327
self,
328
optimizer: torch.optim.Optimizer,
329
t_initial: int,
330
lr_min: float = 0.0,
331
cycle_mul: float = 1.0,
332
cycle_decay: float = 1.0,
333
cycle_limit: int = 1,
334
warmup_t: int = 0,
335
warmup_lr_init: float = 0,
336
warmup_prefix: bool = False,
337
t_in_epochs: bool = True,
338
noise_range_t: tuple = None,
339
noise_pct: float = 0.67,
340
noise_std: float = 1.0,
341
noise_seed: int = None,
342
initialize: bool = True
343
): ...
344
345
class StepLRScheduler:
346
"""
347
Step learning rate scheduler.
348
349
Args:
350
optimizer: Optimizer instance
351
decay_t: Step intervals for decay
352
decay_rate: Decay factor
353
warmup_t: Warmup iterations
354
warmup_lr_init: Initial warmup learning rate
355
t_in_epochs: Interpret intervals as epochs
356
noise_range_t: Noise range for time
357
noise_pct: Noise percentage
358
noise_std: Noise standard deviation
359
noise_seed: Random seed
360
initialize: Initialize learning rates
361
"""
362
363
def __init__(
364
self,
365
optimizer: torch.optim.Optimizer,
366
decay_t: Union[int, List[int]],
367
decay_rate: float = 0.1,
368
warmup_t: int = 0,
369
warmup_lr_init: float = 0,
370
t_in_epochs: bool = True,
371
noise_range_t: tuple = None,
372
noise_pct: float = 0.67,
373
noise_std: float = 1.0,
374
noise_seed: int = None,
375
initialize: bool = True
376
): ...
377
378
class PlateauLRScheduler:
379
"""
380
Plateau-based learning rate scheduler.
381
382
Args:
383
optimizer: Optimizer instance
384
decay_rate: Decay factor when plateau detected
385
patience_t: Patience before decay
386
verbose: Print decay messages
387
threshold: Threshold for measuring improvement
388
cooldown_t: Cooldown period after decay
389
mode: Mode for plateau detection ('min' or 'max')
390
lr_min: Minimum learning rate
391
warmup_t: Warmup iterations
392
warmup_lr_init: Initial warmup learning rate
393
t_in_epochs: Interpret intervals as epochs
394
noise_range_t: Noise range for time
395
noise_pct: Noise percentage
396
noise_std: Noise standard deviation
397
noise_seed: Random seed
398
initialize: Initialize learning rates
399
"""
400
401
def __init__(
402
self,
403
optimizer: torch.optim.Optimizer,
404
decay_rate: float = 0.1,
405
patience_t: int = 10,
406
verbose: bool = True,
407
threshold: float = 1e-4,
408
cooldown_t: int = 0,
409
mode: str = 'max',
410
lr_min: float = 0,
411
warmup_t: int = 0,
412
warmup_lr_init: float = 0,
413
t_in_epochs: bool = True,
414
noise_range_t: tuple = None,
415
noise_pct: float = 0.67,
416
noise_std: float = 1.0,
417
noise_seed: int = None,
418
initialize: bool = True
419
): ...
420
```
421
422
## Loss Functions
423
424
### Loss Classes
425
426
```python { .api }
427
class LabelSmoothingCrossEntropy(torch.nn.Module):
428
"""
429
Cross entropy loss with label smoothing.
430
431
Args:
432
smoothing: Label smoothing factor (0.0 to 1.0)
433
weight: Class weights for unbalanced datasets
434
reduction: Loss reduction ('mean', 'sum', 'none')
435
"""
436
437
def __init__(
438
self,
439
smoothing: float = 0.1,
440
weight: torch.Tensor = None,
441
reduction: str = 'mean'
442
): ...
443
444
class SoftTargetCrossEntropy(torch.nn.Module):
445
"""
446
Cross entropy loss with soft targets (e.g. mixup/cutmix mixed labels or knowledge distillation).
447
448
Args:
449
weight: Class weights
450
size_average: Deprecated, use reduction
451
ignore_index: Index to ignore in loss computation
452
reduce: Deprecated, use reduction
453
reduction: Loss reduction ('mean', 'sum', 'none')
454
"""
455
456
def __init__(
457
self,
458
weight: torch.Tensor = None,
459
size_average: bool = None,
460
ignore_index: int = -100,
461
reduce: bool = None,
462
reduction: str = 'mean'
463
): ...
464
465
class JsdCrossEntropy(torch.nn.Module):
466
"""
467
Jensen-Shannon divergence cross entropy loss.
468
469
Args:
470
num_splits: Number of augmentation splits
471
alpha: Mixing parameter for splits
472
weight: Class weights
473
size_average: Deprecated, use reduction
474
ignore_index: Index to ignore
475
reduce: Deprecated, use reduction
476
reduction: Loss reduction
477
smoothing: Label smoothing factor
478
"""
479
480
def __init__(
481
self,
482
num_splits: int = 2,
483
alpha: float = 12.0,
484
weight: torch.Tensor = None,
485
size_average: bool = None,
486
ignore_index: int = -100,
487
reduce: bool = None,
488
reduction: str = 'mean',
489
smoothing: float = 0.1
490
): ...
491
492
class BinaryCrossEntropy(torch.nn.Module):
493
"""
494
Binary cross entropy loss with optional smoothing.
495
496
Args:
497
smoothing: Label smoothing factor
498
target_threshold: Threshold for hard targets
499
weight: Class weights
500
reduction: Loss reduction
501
pos_weight: Positive class weight
502
"""
503
504
def __init__(
505
self,
506
smoothing: float = 0.0,
507
target_threshold: float = None,
508
weight: torch.Tensor = None,
509
reduction: str = 'mean',
510
pos_weight: torch.Tensor = None
511
): ...
512
513
class AsymmetricLossMultiLabel(torch.nn.Module):
514
"""
515
Asymmetric loss for multi-label classification.
516
517
Args:
518
gamma_neg: Focusing parameter for negative examples
519
gamma_pos: Focusing parameter for positive examples
520
clip: Clipping value for probability
521
eps: Epsilon for numerical stability
522
disable_torch_grad_focal_loss: Disable gradient computation
523
"""
524
525
def __init__(
526
self,
527
gamma_neg: float = 4,
528
gamma_pos: float = 1,
529
clip: float = 0.05,
530
eps: float = 1e-8,
531
disable_torch_grad_focal_loss: bool = False
532
): ...
533
```
534
535
## Training Utilities
536
537
### Model EMA (Exponential Moving Average)
538
539
```python { .api }
540
class ModelEma:
541
"""
542
Model Exponential Moving Average.
543
544
Args:
545
model: Model to track
546
decay: EMA decay rate
547
device: Device for EMA parameters
548
resume: Resume from checkpoint path
549
"""
550
551
def __init__(
552
self,
553
model: torch.nn.Module,
554
decay: float = 0.9999,
555
device: torch.device = None,
556
resume: str = ''
557
): ...
558
559
def update(self, model: torch.nn.Module) -> None:
560
"""Update EMA parameters."""
561
562
def set(self, model: torch.nn.Module) -> None:
563
"""Set EMA parameters from model."""
564
565
class ModelEmaV2:
566
"""
567
Model EMA v2 with improved decay adjustment.
568
569
Args:
570
model: Model to track
571
decay: Base decay rate
572
decay_type: Decay adjustment type
573
device: Device for EMA parameters
574
"""
575
576
def __init__(
577
self,
578
model: torch.nn.Module,
579
decay: float = 0.9999,
580
decay_type: str = 'exponential',
581
device: torch.device = None
582
): ...
583
```
584
585
### Gradient Utilities
586
587
```python { .api }
588
def adaptive_clip_grad(
589
parameters,
590
clip_factor: float = 0.01,
591
eps: float = 1e-3,
592
norm_type: float = 2.0
593
) -> torch.Tensor:
594
"""
595
Adaptive gradient clipping.
596
597
Args:
598
parameters: Model parameters
599
clip_factor: Adaptive clipping factor
600
eps: Epsilon for numerical stability
601
norm_type: Norm type for gradient computation
602
603
Returns:
604
Gradient norm
605
"""
606
607
def dispatch_clip_grad(
608
parameters,
609
value: float,
610
mode: str = 'norm',
611
norm_type: float = 2.0
612
) -> torch.Tensor:
613
"""
614
Dispatch gradient clipping method.
615
616
Args:
617
parameters: Model parameters
618
value: Clipping value
619
mode: Clipping mode ('norm', 'value', 'agc')
620
norm_type: Norm type for gradient computation
621
622
Returns:
623
Gradient norm
624
"""
625
```
626
627
### Checkpointing
628
629
```python { .api }
630
class CheckpointSaver:
631
"""
632
Model checkpoint saver with configurable retention policy.
633
634
Args:
635
model: Model to save
636
optimizer: Optimizer to save
637
args: Training arguments
638
model_ema: EMA model to save
639
amp_scaler: AMP scaler to save
640
checkpoint_prefix: Checkpoint filename prefix
641
recovery_prefix: Recovery checkpoint prefix
642
checkpoint_dir: Directory for checkpoints
643
recovery_dir: Directory for recovery checkpoints
644
decreasing: Set True when a lower metric value is better (e.g. loss); False when higher is better
645
max_history: Maximum checkpoint history
646
unwrap_fn: Function to unwrap model
647
"""
648
649
def __init__(
650
self,
651
model: torch.nn.Module,
652
optimizer: torch.optim.Optimizer,
653
args=None,
654
model_ema: ModelEma = None,
655
amp_scaler=None,
656
checkpoint_prefix: str = 'checkpoint',
657
recovery_prefix: str = 'recovery',
658
checkpoint_dir: str = '',
659
recovery_dir: str = '',
660
decreasing: bool = False,
661
max_history: int = 10,
662
unwrap_fn: Callable = None
663
): ...
664
665
def save_checkpoint(
666
self,
667
epoch: int,
668
metric: float = None
669
) -> str:
670
"""Save checkpoint."""
671
672
def save_recovery(self, epoch: int, batch_idx: int = 0) -> str:
673
"""Save recovery checkpoint."""
674
```
675
676
### Metrics and Monitoring
677
678
```python { .api }
679
class AverageMeter:
680
"""
681
Computes and stores the average and current value.
682
683
Args:
684
name: Meter name
685
fmt: Format string for display
686
"""
687
688
def __init__(self, name: str = '', fmt: str = ':f'): ...
689
690
def reset(self) -> None:
691
"""Reset all statistics."""
692
693
def update(self, val: float, n: int = 1) -> None:
694
"""Update with new value."""
695
696
def accuracy(
697
output: torch.Tensor,
698
target: torch.Tensor,
699
topk: tuple = (1,)
700
) -> List[torch.Tensor]:
701
"""
702
Compute accuracy for specified top-k values.
703
704
Args:
705
output: Model predictions
706
target: Ground truth labels
707
topk: Top-k values to compute
708
709
Returns:
710
List of accuracy values for each k
711
"""
712
```
713
714
## Usage Examples
715
716
### Complete Training Setup
717
718
```python
719
import timm
720
from timm.optim import create_optimizer_v2
721
from timm.scheduler import create_scheduler_v2
722
from timm.loss import LabelSmoothingCrossEntropy
723
from timm.utils import ModelEma, CheckpointSaver, AverageMeter
724
725
# Create model
726
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)
727
728
# Create AdamW optimizer with weight decay
729
optimizer = create_optimizer_v2(
730
model,
731
opt='adamw',
732
lr=1e-3,
733
weight_decay=0.05
734
)
735
736
# Create learning rate scheduler
737
scheduler = create_scheduler_v2(
738
optimizer,
739
sched='cosine',
740
num_epochs=100,
741
warmup_epochs=5,
742
warmup_lr=1e-5,
743
min_lr=1e-6
744
)
745
746
# Create loss function
747
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
748
749
# Create EMA
750
model_ema = ModelEma(model, decay=0.9999)
751
752
# Create checkpoint saver
753
saver = CheckpointSaver(
754
model=model,
755
optimizer=optimizer,
756
model_ema=model_ema,
757
checkpoint_dir='./checkpoints',
758
max_history=5
759
)
760
761
# Metrics
762
losses = AverageMeter('Loss', ':.4e')
763
top1 = AverageMeter('Acc@1', ':6.2f')
764
```
765
766
### Advanced Optimizer Configuration
767
768
```python
769
from timm.optim import param_groups_layer_decay, Lamb, Lookahead
770
771
# Create parameter groups with layer decay
772
param_groups = param_groups_layer_decay(
773
model,
774
weight_decay=0.05,
775
layer_decay=0.8
776
)
777
778
# Create LAMB optimizer
779
base_optimizer = Lamb(param_groups, lr=1e-3)
780
781
# Wrap with Lookahead
782
optimizer = Lookahead(base_optimizer, alpha=0.5, k=6)
783
```
784
785
## Types
786
787
```python { .api }
788
from typing import Optional, Union, List, Dict, Callable, Any, Tuple
789
import torch
790
791
# Optimizer and scheduler types
792
OptimizerType = torch.optim.Optimizer
793
SchedulerType = torch.optim.lr_scheduler._LRScheduler
794
795
# Parameter types
796
ParamGroup = Dict[str, Any]
797
ParamGroups = List[ParamGroup]
798
799
# Loss function type
800
LossFunction = torch.nn.Module
801
802
# Metric types
803
MetricValue = Union[float, torch.Tensor]
804
MetricDict = Dict[str, MetricValue]
805
```