or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

data.md features.md index.md layers.md models.md training.md utils.md

docs/training.md

0

# Training Infrastructure

1

2

Comprehensive training utilities including optimizers, learning rate schedulers, loss functions, and training helpers for building complete training pipelines.

3

4

## Capabilities

5

6

### Optimizer Creation

7

8

Factory functions for creating optimizers with advanced configurations and parameter grouping strategies.

9

10

```python { .api }

11

def create_optimizer_v2(

12

model_or_params: Union[torch.nn.Module, ParamsT],

13

opt: str = 'sgd',

14

lr: Optional[float] = None,

15

weight_decay: float = 0.0,

16

momentum: float = 0.9,

17

foreach: Optional[bool] = None,

18

filter_bias_and_bn: bool = True,

19

layer_decay: Optional[float] = None,

20

layer_decay_min_scale: float = 0.0,

21

layer_decay_no_opt_scale: Optional[float] = None,

22

param_group_fn: Optional[Callable[[torch.nn.Module], ParamsT]] = None,

23

**kwargs: Any

24

) -> torch.optim.Optimizer:

25

"""

26

Create optimizer with v2 interface.

27

28

Args:

29

model_or_params: Model instance or parameter groups

30

opt: Optimizer name ('sgd', 'adam', 'adamw', 'rmsprop', etc.)

31

lr: Learning rate

32

weight_decay: Weight decay coefficient

33

momentum: Momentum coefficient (for SGD)

34

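eps: Epsilon for numerical stability (not a named parameter; forwarded via **kwargs)

35

betas: Beta coefficients for Adam-family optimizers (not a named parameter; forwarded via **kwargs)

36

opt_args: Additional optimizer arguments (not a named parameter; forwarded via **kwargs)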

37

**kwargs: Additional arguments

38

39

Returns:

40

Configured optimizer instance

41

"""

42

43

def create_optimizer(

44

args,

45

model: torch.nn.Module,

46

filter_bias_and_bn: bool = True

47

):

48

"""

49

Create optimizer from arguments (legacy interface).

50

51

Args:

52

args: Arguments namespace with optimizer configuration

53

model: Model to optimize

54

filter_bias_and_bn: Filter bias and batch norm parameters

55

56

Returns:

57

Configured optimizer

58

"""

59

60

def list_optimizers() -> List[str]:

61

"""

62

List available optimizer names.

63

64

Returns:

65

List of supported optimizer names

66

"""

67

68

def get_optimizer_class(optimizer_name: str):

69

"""

70

Get optimizer class by name.

71

72

Args:

73

optimizer_name: Name of optimizer

74

75

Returns:

76

Optimizer class

77

"""

78

```

79

80

### Parameter Grouping

81

82

Functions for creating parameter groups with different learning rates, weight decay, and layer-specific configurations.

83

84

```python { .api }

85

def param_groups_layer_decay(

86

model: torch.nn.Module,

87

weight_decay: float = 0.05,

88

no_weight_decay_list: List[str] = None,

89

layer_decay: float = 0.75,

90

end_lr_scale: float = 1.0

91

) -> List[dict]:

92

"""

93

Create parameter groups with layer-wise learning rate decay.

94

95

Args:

96

model: Model to create parameter groups for

97

weight_decay: Base weight decay rate

98

no_weight_decay_list: Parameters to exclude from weight decay

99

layer_decay: Layer decay factor

100

end_lr_scale: Learning rate scale for final layer

101

102

Returns:

103

List of parameter group dictionaries

104

"""

105

106

def param_groups_weight_decay(

107

model: torch.nn.Module,

108

weight_decay: float = 1e-5,

109

no_weight_decay_list: List[str] = None

110

) -> List[dict]:

111

"""

112

Create parameter groups with selective weight decay.

113

114

Args:

115

model: Model to create parameter groups for

116

weight_decay: Weight decay rate

117

no_weight_decay_list: Parameters to exclude from weight decay

118

119

Returns:

120

List of parameter group dictionaries

121

"""

122

```

123

124

## Optimizer Classes

125

126

### Custom Optimizers

127

128

```python { .api }

129

class AdaBelief(torch.optim.Optimizer):

130

"""

131

AdaBelief optimizer.

132

133

Args:

134

params: Iterable of parameters

135

lr: Learning rate

136

betas: Beta coefficients

137

eps: Epsilon for numerical stability

138

weight_decay: Weight decay coefficient

139

amsgrad: Use AMSGrad variant

140

weight_decouple: Decouple weight decay

141

fixed_decay: Use fixed decay

142

rectify: Use rectification

143

"""

144

145

def __init__(

146

self,

147

params,

148

lr: float = 1e-3,

149

betas: tuple = (0.9, 0.999),

150

eps: float = 1e-16,

151

weight_decay: float = 0,

152

amsgrad: bool = False,

153

weight_decouple: bool = True,

154

fixed_decay: bool = False,

155

rectify: bool = True

156

): ...

157

158

class Lamb(torch.optim.Optimizer):

159

"""

160

LAMB (Layer-wise Adaptive Moments) optimizer.

161

162

Args:

163

params: Iterable of parameters

164

lr: Learning rate

165

betas: Beta coefficients

166

eps: Epsilon for numerical stability

167

weight_decay: Weight decay coefficient

168

grad_averaging: Use gradient averaging

169

max_grad_norm: Maximum gradient norm

170

trust_clip: Trust region clipping

171

always_adapt: Always adapt learning rate

172

"""

173

174

def __init__(

175

self,

176

params,

177

lr: float = 1e-3,

178

betas: tuple = (0.9, 0.999),

179

eps: float = 1e-6,

180

weight_decay: float = 0.01,

181

grad_averaging: bool = True,

182

max_grad_norm: float = 1.0,

183

trust_clip: bool = False,

184

always_adapt: bool = False

185

): ...

186

187

class Lion(torch.optim.Optimizer):

188

"""

189

Lion (EvoLved Sign Momentum) optimizer.

190

191

Args:

192

params: Iterable of parameters

193

lr: Learning rate

194

betas: Beta coefficients for momentum

195

weight_decay: Weight decay coefficient

196

use_triton: Use Triton kernel implementation

197

"""

198

199

def __init__(

200

self,

201

params,

202

lr: float = 1e-4,

203

betas: tuple = (0.9, 0.99),

204

weight_decay: float = 0.0,

205

use_triton: bool = False

206

): ...

207

208

class Lookahead(torch.optim.Optimizer):

209

"""

210

Lookahead optimizer wrapper.

211

212

Args:

213

base_optimizer: Base optimizer to wrap

214

alpha: Lookahead step size

215

k: Lookahead frequency

216

pullback_momentum: Pullback momentum mode

217

"""

218

219

def __init__(

220

self,

221

base_optimizer: torch.optim.Optimizer,

222

alpha: float = 0.5,

223

k: int = 6,

224

pullback_momentum: str = "none"

225

): ...

226

```

227

228

## Learning Rate Schedulers

229

230

### Scheduler Creation

231

232

```python { .api }

233

def create_scheduler_v2(

234

optimizer: torch.optim.Optimizer,

235

sched: str = 'cosine',

236

num_epochs: int = 300,

237

decay_epochs: int = 90,

238

decay_milestones: List[int] = (90, 180, 270),

239

cooldown_epochs: int = 0,

240

patience_epochs: int = 10,

241

decay_rate: float = 0.1,

242

min_lr: float = 0,

243

warmup_lr: float = 1e-5,

244

warmup_epochs: int = 0,

245

warmup_prefix: bool = False,

246

noise: Union[float, List[float]] = None,

247

noise_pct: float = 0.67,

248

noise_std: float = 1.0,

249

noise_seed: int = 42,

250

cycle_mul: float = 1.0,

251

cycle_decay: float = 0.1,

252

cycle_limit: int = 1,

253

k_decay: float = 1.0,

254

plateau_mode: str = 'max',

255

step_on_epochs: bool = True,

256

updates_per_epoch: int = 0

257

):

258

"""

259

Create learning rate scheduler with v2 interface.

260

261

Args:

262

optimizer: Optimizer instance

263

sched: Scheduler type ('step', 'cosine', 'tanh', 'poly', 'plateau', etc.)

264

num_epochs: Total number of training epochs

265

decay_epochs: Epochs between learning rate decay

266

decay_rate: Learning rate decay factor

267

min_lr: Minimum learning rate

268

warmup_lr: Warmup initial learning rate

269

warmup_epochs: Number of warmup epochs

270

cooldown_epochs: Number of cooldown epochs

271

patience_epochs: Patience for plateau scheduler

272

cycle_mul: Cycle length multiplier

273

cycle_decay: Cycle decay factor

274

cycle_limit: Maximum number of cycles

275

noise: Learning rate noise range

276

noise_pct: Noise percentage

277

noise_std: Noise standard deviation

278

noise_seed: Random seed for noise

279

k_decay: K decay factor

280

plateau_mode: Plateau mode ('min' or 'max')

281

step_on_epochs: Step on epochs vs iterations

282

updates_per_epoch: Updates per epoch for iteration-based stepping

283

**kwargs: Additional scheduler arguments

284

285

Returns:

286

Configured scheduler instance

287

"""

288

289

def scheduler_kwargs(args) -> dict:

290

"""

291

Extract scheduler keyword arguments from args.

292

293

Args:

294

args: Arguments namespace

295

296

Returns:

297

Dictionary of scheduler arguments

298

"""

299

```

300

301

### Scheduler Classes

302

303

```python { .api }

304

class CosineLRScheduler:

305

"""

306

Cosine annealing learning rate scheduler with warm restarts.

307

308

Args:

309

optimizer: Optimizer instance

310

t_initial: Initial number of epochs/iterations

311

lr_min: Minimum learning rate

312

cycle_mul: Cycle length multiplier

313

cycle_decay: Cycle amplitude decay

314

cycle_limit: Maximum number of cycles

315

warmup_t: Warmup iterations

316

warmup_lr_init: Initial warmup learning rate

317

warmup_prefix: Warmup before first cycle

318

t_in_epochs: Interpret t_initial as epochs

319

noise_range_t: Noise range for time

320

noise_pct: Noise percentage

321

noise_std: Noise standard deviation

322

noise_seed: Random seed

323

initialize: Initialize learning rates

324

"""

325

326

def __init__(

327

self,

328

optimizer: torch.optim.Optimizer,

329

t_initial: int,

330

lr_min: float = 0.0,

331

cycle_mul: float = 1.0,

332

cycle_decay: float = 1.0,

333

cycle_limit: int = 1,

334

warmup_t: int = 0,

335

warmup_lr_init: float = 0,

336

warmup_prefix: bool = False,

337

t_in_epochs: bool = True,

338

noise_range_t: tuple = None,

339

noise_pct: float = 0.67,

340

noise_std: float = 1.0,

341

noise_seed: int = None,

342

initialize: bool = True

343

): ...

344

345

class StepLRScheduler:

346

"""

347

Step learning rate scheduler.

348

349

Args:

350

optimizer: Optimizer instance

351

decay_t: Step intervals for decay

352

decay_rate: Decay factor

353

warmup_t: Warmup iterations

354

warmup_lr_init: Initial warmup learning rate

355

t_in_epochs: Interpret intervals as epochs

356

noise_range_t: Noise range for time

357

noise_pct: Noise percentage

358

noise_std: Noise standard deviation

359

noise_seed: Random seed

360

initialize: Initialize learning rates

361

"""

362

363

def __init__(

364

self,

365

optimizer: torch.optim.Optimizer,

366

decay_t: Union[int, List[int]],

367

decay_rate: float = 0.1,

368

warmup_t: int = 0,

369

warmup_lr_init: float = 0,

370

t_in_epochs: bool = True,

371

noise_range_t: tuple = None,

372

noise_pct: float = 0.67,

373

noise_std: float = 1.0,

374

noise_seed: int = None,

375

initialize: bool = True

376

): ...

377

378

class PlateauLRScheduler:

379

"""

380

Plateau-based learning rate scheduler.

381

382

Args:

383

optimizer: Optimizer instance

384

decay_rate: Decay factor when plateau detected

385

patience_t: Patience before decay

386

verbose: Print decay messages

387

threshold: Threshold for measuring improvement

388

cooldown_t: Cooldown period after decay

389

mode: Mode for plateau detection ('min' or 'max')

390

lr_min: Minimum learning rate

391

warmup_t: Warmup iterations

392

warmup_lr_init: Initial warmup learning rate

393

t_in_epochs: Interpret intervals as epochs

394

noise_range_t: Noise range for time

395

noise_pct: Noise percentage

396

noise_std: Noise standard deviation

397

noise_seed: Random seed

398

initialize: Initialize learning rates

399

"""

400

401

def __init__(

402

self,

403

optimizer: torch.optim.Optimizer,

404

decay_rate: float = 0.1,

405

patience_t: int = 10,

406

verbose: bool = True,

407

threshold: float = 1e-4,

408

cooldown_t: int = 0,

409

mode: str = 'max',

410

lr_min: float = 0,

411

warmup_t: int = 0,

412

warmup_lr_init: float = 0,

413

t_in_epochs: bool = True,

414

noise_range_t: tuple = None,

415

noise_pct: float = 0.67,

416

noise_std: float = 1.0,

417

noise_seed: int = None,

418

initialize: bool = True

419

): ...

420

```

421

422

## Loss Functions

423

424

### Loss Classes

425

426

```python { .api }

427

class LabelSmoothingCrossEntropy(torch.nn.Module):

428

"""

429

Cross entropy loss with label smoothing.

430

431

Args:

432

smoothing: Label smoothing factor (0.0 to 1.0)

433

weight: Class weights for unbalanced datasets

434

reduction: Loss reduction ('mean', 'sum', 'none')

435

"""

436

437

def __init__(

438

self,

439

smoothing: float = 0.1,

440

weight: torch.Tensor = None,

441

reduction: str = 'mean'

442

): ...

443

444

class SoftTargetCrossEntropy(torch.nn.Module):

445

"""

446

Cross entropy loss with soft targets (for knowledge distillation).

447

448

Args:

449

weight: Class weights

450

size_average: Deprecated, use reduction

451

ignore_index: Index to ignore in loss computation

452

reduce: Deprecated, use reduction

453

reduction: Loss reduction ('mean', 'sum', 'none')

454

"""

455

456

def __init__(

457

self,

458

weight: torch.Tensor = None,

459

size_average: bool = None,

460

ignore_index: int = -100,

461

reduce: bool = None,

462

reduction: str = 'mean'

463

): ...

464

465

class JsdCrossEntropy(torch.nn.Module):

466

"""

467

Jensen-Shannon divergence cross entropy loss.

468

469

Args:

470

num_splits: Number of augmentation splits

471

alpha: Mixing parameter for splits

472

weight: Class weights

473

size_average: Deprecated, use reduction

474

ignore_index: Index to ignore

475

reduce: Deprecated, use reduction

476

reduction: Loss reduction

477

smoothing: Label smoothing factor

478

"""

479

480

def __init__(

481

self,

482

num_splits: int = 2,

483

alpha: float = 12.0,

484

weight: torch.Tensor = None,

485

size_average: bool = None,

486

ignore_index: int = -100,

487

reduce: bool = None,

488

reduction: str = 'mean',

489

smoothing: float = 0.1

490

): ...

491

492

class BinaryCrossEntropy(torch.nn.Module):

493

"""

494

Binary cross entropy loss with optional smoothing.

495

496

Args:

497

smoothing: Label smoothing factor

498

target_threshold: Threshold for hard targets

499

weight: Class weights

500

reduction: Loss reduction

501

pos_weight: Positive class weight

502

"""

503

504

def __init__(

505

self,

506

smoothing: float = 0.0,

507

target_threshold: float = None,

508

weight: torch.Tensor = None,

509

reduction: str = 'mean',

510

pos_weight: torch.Tensor = None

511

): ...

512

513

class AsymmetricLossMultiLabel(torch.nn.Module):

514

"""

515

Asymmetric loss for multi-label classification.

516

517

Args:

518

gamma_neg: Focusing parameter for negative examples

519

gamma_pos: Focusing parameter for positive examples

520

clip: Clipping value for probability

521

eps: Epsilon for numerical stability

522

disable_torch_grad_focal_loss: Disable gradient computation

523

"""

524

525

def __init__(

526

self,

527

gamma_neg: float = 4,

528

gamma_pos: float = 1,

529

clip: float = 0.05,

530

eps: float = 1e-8,

531

disable_torch_grad_focal_loss: bool = False

532

): ...

533

```

534

535

## Training Utilities

536

537

### Model EMA (Exponential Moving Average)

538

539

```python { .api }

540

class ModelEma:

541

"""

542

Model Exponential Moving Average.

543

544

Args:

545

model: Model to track

546

decay: EMA decay rate

547

device: Device for EMA parameters

548

resume: Resume from checkpoint path

549

"""

550

551

def __init__(

552

self,

553

model: torch.nn.Module,

554

decay: float = 0.9999,

555

device: torch.device = None,

556

resume: str = ''

557

): ...

558

559

def update(self, model: torch.nn.Module) -> None:

560

"""Update EMA parameters."""

561

562

def set(self, model: torch.nn.Module) -> None:

563

"""Set EMA parameters from model."""

564

565

class ModelEmaV2:

566

"""

567

Model EMA v2 with improved decay adjustment.

568

569

Args:

570

model: Model to track

571

decay: Base decay rate

572

decay_type: Decay adjustment type

573

device: Device for EMA parameters

574

"""

575

576

def __init__(

577

self,

578

model: torch.nn.Module,

579

decay: float = 0.9999,

580

decay_type: str = 'exponential',

581

device: torch.device = None

582

): ...

583

```

584

585

### Gradient Utilities

586

587

```python { .api }

588

def adaptive_clip_grad(

589

parameters,

590

clip_factor: float = 0.01,

591

eps: float = 1e-3,

592

norm_type: float = 2.0

593

) -> torch.Tensor:

594

"""

595

Adaptive gradient clipping.

596

597

Args:

598

parameters: Model parameters

599

clip_factor: Adaptive clipping factor

600

eps: Epsilon for numerical stability

601

norm_type: Norm type for gradient computation

602

603

Returns:

604

Gradient norm

605

"""

606

607

def dispatch_clip_grad(

608

parameters,

609

value: float,

610

mode: str = 'norm',

611

norm_type: float = 2.0

612

) -> torch.Tensor:

613

"""

614

Dispatch gradient clipping method.

615

616

Args:

617

parameters: Model parameters

618

value: Clipping value

619

mode: Clipping mode ('norm', 'value', 'agc')

620

norm_type: Norm type for gradient computation

621

622

Returns:

623

Gradient norm

624

"""

625

```

626

627

### Checkpointing

628

629

```python { .api }

630

class CheckpointSaver:

631

"""

632

Model checkpoint saver with configurable retention policy.

633

634

Args:

635

model: Model to save

636

optimizer: Optimizer to save

637

args: Training arguments

638

model_ema: EMA model to save

639

amp_scaler: AMP scaler to save

640

checkpoint_prefix: Checkpoint filename prefix

641

recovery_prefix: Recovery checkpoint prefix

642

checkpoint_dir: Directory for checkpoints

643

recovery_dir: Directory for recovery checkpoints

644

decreasing: Monitor decreasing metric

645

max_history: Maximum checkpoint history

646

unwrap_fn: Function to unwrap model

647

"""

648

649

def __init__(

650

self,

651

model: torch.nn.Module,

652

optimizer: torch.optim.Optimizer,

653

args=None,

654

model_ema: ModelEma = None,

655

amp_scaler=None,

656

checkpoint_prefix: str = 'checkpoint',

657

recovery_prefix: str = 'recovery',

658

checkpoint_dir: str = '',

659

recovery_dir: str = '',

660

decreasing: bool = False,

661

max_history: int = 10,

662

unwrap_fn: Callable = None

663

): ...

664

665

def save_checkpoint(

666

self,

667

epoch: int,

668

metric: float = None

669

) -> str:

670

"""Save checkpoint."""

671

672

def save_recovery(self, epoch: int, batch_idx: int = 0) -> str:

673

"""Save recovery checkpoint."""

674

```

675

676

### Metrics and Monitoring

677

678

```python { .api }

679

class AverageMeter:

680

"""

681

Computes and stores the average and current value.

682

683

Args:

684

name: Meter name

685

fmt: Format string for display

686

"""

687

688

def __init__(self, name: str = '', fmt: str = ':f'): ...

689

690

def reset(self) -> None:

691

"""Reset all statistics."""

692

693

def update(self, val: float, n: int = 1) -> None:

694

"""Update with new value."""

695

696

def accuracy(

697

output: torch.Tensor,

698

target: torch.Tensor,

699

topk: tuple = (1,)

700

) -> List[torch.Tensor]:

701

"""

702

Compute accuracy for specified top-k values.

703

704

Args:

705

output: Model predictions

706

target: Ground truth labels

707

topk: Top-k values to compute

708

709

Returns:

710

List of accuracy values for each k

711

"""

712

```

713

714

## Usage Examples

715

716

### Complete Training Setup

717

718

```python

719

import timm

720

from timm.optim import create_optimizer_v2

721

from timm.scheduler import create_scheduler_v2

722

from timm.loss import LabelSmoothingCrossEntropy

723

from timm.utils import ModelEma, CheckpointSaver, AverageMeter

724

725

# Create model

726

model = timm.create_model('resnet50', pretrained=True, num_classes=1000)

727

728

# Create AdamW optimizer with weight decay

729

optimizer = create_optimizer_v2(

730

model,

731

opt='adamw',

732

lr=1e-3,

733

weight_decay=0.05

734

)

735

736

# Create learning rate scheduler

737

scheduler = create_scheduler_v2(

738

optimizer,

739

sched='cosine',

740

num_epochs=100,

741

warmup_epochs=5,

742

warmup_lr=1e-5,

743

min_lr=1e-6

744

)

745

746

# Create loss function

747

criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

748

749

# Create EMA

750

model_ema = ModelEma(model, decay=0.9999)

751

752

# Create checkpoint saver

753

saver = CheckpointSaver(

754

model=model,

755

optimizer=optimizer,

756

model_ema=model_ema,

757

checkpoint_dir='./checkpoints',

758

max_history=5

759

)

760

761

# Metrics

762

losses = AverageMeter('Loss', ':.4e')

763

top1 = AverageMeter('Acc@1', ':6.2f')

764

```

765

766

### Advanced Optimizer Configuration

767

768

```python

769

from timm.optim import param_groups_layer_decay, Lamb, Lookahead

770

771

# Create parameter groups with layer decay

772

param_groups = param_groups_layer_decay(

773

model,

774

weight_decay=0.05,

775

layer_decay=0.8

776

)

777

778

# Create LAMB optimizer

779

base_optimizer = Lamb(param_groups, lr=1e-3)

780

781

# Wrap with Lookahead

782

optimizer = Lookahead(base_optimizer, alpha=0.5, k=6)

783

```

784

785

## Types

786

787

```python { .api }

788

from typing import Optional, Union, List, Dict, Callable, Any, Tuple

789

import torch

790

791

# Optimizer and scheduler types

792

OptimizerType = torch.optim.Optimizer

793

SchedulerType = torch.optim.lr_scheduler._LRScheduler

794

795

# Parameter types

796

ParamGroup = Dict[str, Any]

797

ParamGroups = List[ParamGroup]

798

799

# Loss function type

800

LossFunction = torch.nn.Module

801

802

# Metric types

803

MetricValue = Union[float, torch.Tensor]

804

MetricDict = Dict[str, MetricValue]

805

```