or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-features.md devices-distributed.md index.md mathematical-functions.md neural-networks.md tensor-operations.md training.md

training.md docs/

0

# Training and Optimization

1

2

Optimizers, learning rate schedulers, and training utilities for model optimization and parameter updates. The torch.optim module provides optimization algorithms and learning rate scheduling strategies.

3

4

## Capabilities

5

6

### Optimizers

7

8

Optimization algorithms for updating model parameters during training.

9

10

```python { .api }

11

class Optimizer:

12

"""Base class for all optimizers."""

13

def __init__(self, params, defaults): ...

14

def state_dict(self):

15

"""Return optimizer state dictionary."""

16

def load_state_dict(self, state_dict):

17

"""Load optimizer state."""

18

def zero_grad(self, set_to_none: bool = False):

19

"""Set gradients to zero."""

20

def step(self, closure=None):

21

"""Perform optimization step."""

22

def add_param_group(self, param_group):

23

"""Add parameter group."""

24

```

25

26

### SGD Optimizers

27

28

Stochastic Gradient Descent and variants.

29

30

```python { .api }

31

class SGD(Optimizer):

32

"""Stochastic Gradient Descent optimizer."""

33

def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False):

34

"""

35

Parameters:

36

- params: Iterable of parameters or parameter groups

37

- lr: Learning rate

38

- momentum: Momentum factor (default: 0)

39

- dampening: Dampening for momentum (default: 0)

40

- weight_decay: Weight decay (L2 penalty) (default: 0)

41

- nesterov: Enable Nesterov momentum (default: False)

42

"""

43

def step(self, closure=None): ...

44

45

class ASGD(Optimizer):

46

"""Averaged Stochastic Gradient Descent."""

47

def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0):

48

"""

49

Parameters:

50

- params: Iterable of parameters

51

- lr: Learning rate (default: 1e-2)

52

- lambd: Decay term (default: 1e-4)

53

- alpha: Power for eta update (default: 0.75)

54

- t0: Point at which to start averaging (default: 1e6)

55

- weight_decay: Weight decay (default: 0)

56

"""

57

def step(self, closure=None): ...

58

```

59

60

### Adam-family Optimizers

61

62

Adam and its variants for adaptive learning rates.

63

64

```python { .api }

65

class Adam(Optimizer):

66

"""Adam optimizer."""

67

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False):

68

"""

69

Parameters:

70

- params: Iterable of parameters

71

- lr: Learning rate (default: 1e-3)

72

- betas: Coefficients for momentum and squared gradient averaging (default: (0.9, 0.999))

73

- eps: Term for numerical stability (default: 1e-8)

74

- weight_decay: Weight decay (default: 0)

75

- amsgrad: Use AMSGrad variant (default: False)

76

"""

77

def step(self, closure=None): ...

78

79

class AdamW(Optimizer):

80

"""AdamW optimizer with decoupled weight decay."""

81

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False):

82

"""

83

Parameters:

84

- params: Iterable of parameters

85

- lr: Learning rate (default: 1e-3)

86

- betas: Coefficients for momentum and squared gradient averaging

87

- eps: Term for numerical stability

88

- weight_decay: Weight decay coefficient (default: 1e-2)

89

- amsgrad: Use AMSGrad variant

90

"""

91

def step(self, closure=None): ...

92

93

class Adamax(Optimizer):

94

"""Adamax optimizer (Adam based on infinity norm)."""

95

def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):

96

"""

97

Parameters:

98

- params: Iterable of parameters

99

- lr: Learning rate (default: 2e-3)

100

- betas: Coefficients for momentum and squared gradient averaging

101

- eps: Term for numerical stability

102

- weight_decay: Weight decay

103

"""

104

def step(self, closure=None): ...

105

106

class NAdam(Optimizer):

107

"""NAdam optimizer (Adam with Nesterov momentum)."""

108

def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, momentum_decay=4e-3):

109

"""

110

Parameters:

111

- params: Iterable of parameters

112

- lr: Learning rate (default: 2e-3)

113

- betas: Coefficients for momentum and squared gradient averaging

114

- eps: Term for numerical stability

115

- weight_decay: Weight decay

116

- momentum_decay: Momentum decay

117

"""

118

def step(self, closure=None): ...

119

120

class RAdam(Optimizer):

121

"""RAdam optimizer (Rectified Adam)."""

122

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):

123

"""

124

Parameters:

125

- params: Iterable of parameters

126

- lr: Learning rate (default: 1e-3)

127

- betas: Coefficients for momentum and squared gradient averaging

128

- eps: Term for numerical stability

129

- weight_decay: Weight decay

130

"""

131

def step(self, closure=None): ...

132

```

133

134

### Adaptive Learning Rate Optimizers

135

136

Optimizers that adapt learning rates based on gradient history.

137

138

```python { .api }

139

class Adagrad(Optimizer):

140

"""Adagrad optimizer."""

141

def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10):

142

"""

143

Parameters:

144

- params: Iterable of parameters

145

- lr: Learning rate (default: 1e-2)

146

- lr_decay: Learning rate decay (default: 0)

147

- weight_decay: Weight decay (default: 0)

148

- initial_accumulator_value: Initial value for accumulator

149

- eps: Term for numerical stability

150

"""

151

def step(self, closure=None): ...

152

153

class Adadelta(Optimizer):

154

"""Adadelta optimizer."""

155

def __init__(self, params, lr=1.0, rho=0.9, eps=1e-6, weight_decay=0):

156

"""

157

Parameters:

158

- params: Iterable of parameters

159

- lr: Coefficient that scales delta (default: 1.0)

160

- rho: Coefficient for squared gradient averaging (default: 0.9)

161

- eps: Term for numerical stability (default: 1e-6)

162

- weight_decay: Weight decay (default: 0)

163

"""

164

def step(self, closure=None): ...

165

166

class RMSprop(Optimizer):

167

"""RMSprop optimizer."""

168

def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):

169

"""

170

Parameters:

171

- params: Iterable of parameters

172

- lr: Learning rate (default: 1e-2)

173

- alpha: Smoothing constant (default: 0.99)

174

- eps: Term for numerical stability (default: 1e-8)

175

- weight_decay: Weight decay (default: 0)

176

- momentum: Momentum factor (default: 0)

177

- centered: Compute centered RMSprop (default: False)

178

"""

179

def step(self, closure=None): ...

180

181

class Rprop(Optimizer):

182

"""Rprop optimizer."""

183

def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)):

184

"""

185

Parameters:

186

- params: Iterable of parameters

187

- lr: Learning rate (default: 1e-2)

188

- etas: Pair of (etaminus, etaplus), the multiplicative decrease and increase factors (default: (0.5, 1.2))

189

- step_sizes: Pair of minimal and maximal allowed step sizes

190

"""

191

def step(self, closure=None): ...

192

```

193

194

### Advanced Optimizers

195

196

Specialized optimization algorithms.

197

198

```python { .api }

199

class LBFGS(Optimizer):

200

"""Limited-memory BFGS optimizer."""

201

def __init__(self, params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-7,

202

tolerance_change=1e-9, history_size=100, line_search_fn=None):

203

"""

204

Parameters:

205

- params: Iterable of parameters

206

- lr: Learning rate (default: 1)

207

- max_iter: Maximum number of iterations per optimization step

208

- max_eval: Maximum number of function evaluations per step

209

- tolerance_grad: Termination tolerance on first order optimality

210

- tolerance_change: Termination tolerance on function/parameter changes

211

- history_size: Update history size

212

- line_search_fn: Line search function ('strong_wolfe' or None)

213

"""

214

def step(self, closure): ...

215

216

class SparseAdam(Optimizer):

217

"""Adam optimizer for sparse tensors."""

218

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):

219

"""

220

Parameters:

221

- params: Iterable of parameters

222

- lr: Learning rate (default: 1e-3)

223

- betas: Coefficients for momentum and squared gradient averaging

224

- eps: Term for numerical stability

225

"""

226

def step(self, closure=None): ...

227

228

class Adafactor(Optimizer):

229

"""Adafactor optimizer for memory-efficient training."""

230

def __init__(self, params, lr=None, eps2=1e-30, cliping_threshold=1.0, decay_rate=-0.8,

231

beta1=None, weight_decay=0.0, scale_parameter=True, relative_step=True):

232

"""

233

Parameters:

234

- params: Iterable of parameters

235

- lr: Learning rate (None for automatic scaling)

236

- eps2: Regularization constant for second moment

237

- cliping_threshold: Threshold of root mean square of final gradient update

238

- decay_rate: Coefficient for moving average of squared gradient

239

- beta1: Coefficient for moving average of gradient

240

- weight_decay: Weight decay

241

- scale_parameter: Scale learning rate by root mean square of parameter

242

- relative_step: Set learning rate relative to current step

243

"""

244

def step(self, closure=None): ...

245

```

246

247

### Learning Rate Schedulers

248

249

Learning rate scheduling strategies from `torch.optim.lr_scheduler` for adjusting the learning rate during training.

250

251

```python { .api }

252

class LRScheduler:

253

"""Base class for learning rate schedulers."""

254

def __init__(self, optimizer, last_epoch=-1, verbose=False): ...

255

def state_dict(self):

256

"""Return scheduler state dictionary."""

257

def load_state_dict(self, state_dict):

258

"""Load scheduler state."""

259

def get_last_lr(self):

260

"""Return last computed learning rates."""

261

def step(self, epoch=None):

262

"""Update learning rates."""

263

264

class StepLR(LRScheduler):

265

"""Decay learning rate by gamma every step_size epochs."""

266

def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):

267

"""

268

Parameters:

269

- optimizer: Wrapped optimizer

270

- step_size: Period of learning rate decay

271

- gamma: Multiplicative factor of learning rate decay (default: 0.1)

272

- last_epoch: Index of last epoch (default: -1)

273

- verbose: Print message on every update (default: False)

274

"""

275

276

class MultiStepLR(LRScheduler):

277

"""Decay learning rate by gamma at specified milestones."""

278

def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):

279

"""

280

Parameters:

281

- optimizer: Wrapped optimizer

282

- milestones: List of epoch indices for decay

283

- gamma: Multiplicative factor of learning rate decay

284

- last_epoch: Index of last epoch

285

- verbose: Print message on every update

286

"""

287

288

class ExponentialLR(LRScheduler):

289

"""Decay learning rate by gamma every epoch."""

290

def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):

291

"""

292

Parameters:

293

- optimizer: Wrapped optimizer

294

- gamma: Multiplicative factor of learning rate decay

295

- last_epoch: Index of last epoch

296

- verbose: Print message on every update

297

"""

298

299

class CosineAnnealingLR(LRScheduler):

300

"""Cosine annealing learning rate schedule."""

301

def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):

302

"""

303

Parameters:

304

- optimizer: Wrapped optimizer

305

- T_max: Maximum number of iterations

306

- eta_min: Minimum learning rate (default: 0)

307

- last_epoch: Index of last epoch

308

- verbose: Print message on every update

309

"""

310

311

class CosineAnnealingWarmRestarts(LRScheduler):

312

"""Cosine annealing with warm restarts."""

313

def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):

314

"""

315

Parameters:

316

- optimizer: Wrapped optimizer

317

- T_0: Number of iterations until the first restart

318

- T_mult: Factor to increase T_i after restart (default: 1)

319

- eta_min: Minimum learning rate (default: 0)

320

- last_epoch: Index of last epoch

321

- verbose: Print message on every update

322

"""

323

324

class ReduceLROnPlateau:

325

"""Reduce learning rate when metric stops improving."""

326

def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False,

327

threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8):

328

"""

329

Parameters:

330

- optimizer: Wrapped optimizer

331

- mode: 'min' or 'max' for metric improvement direction

332

- factor: Factor to reduce learning rate (default: 0.1)

333

- patience: Number of epochs with no improvement to wait

334

- verbose: Print message when reducing lr

335

- threshold: Threshold for measuring new optimum

336

- threshold_mode: 'rel' or 'abs' for threshold comparison

337

- cooldown: Number of epochs to wait before resuming normal operation

338

- min_lr: Lower bound on learning rate

339

- eps: Minimal decay applied to lr

340

"""

341

def step(self, metrics, epoch=None): ...

342

343

class CyclicLR(LRScheduler):

344

"""Cyclical learning rate policy."""

345

def __init__(self, optimizer, base_lr, max_lr, step_size_up=2000, step_size_down=None,

346

mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True,

347

base_momentum=0.8, max_momentum=0.9, last_epoch=-1, verbose=False):

348

"""

349

Parameters:

350

- optimizer: Wrapped optimizer

351

- base_lr: Lower learning rate boundary

352

- max_lr: Upper learning rate boundary

353

- step_size_up: Number of training iterations in increasing half

354

- step_size_down: Number of training iterations in decreasing half

355

- mode: 'triangular', 'triangular2', or 'exp_range'

356

- gamma: Constant in 'exp_range' scaling function

357

- scale_fn: Custom scaling policy function

358

- scale_mode: 'cycle' or 'iterations'

359

- cycle_momentum: Cycle momentum inversely to learning rate

360

- base_momentum: Lower momentum boundary

361

- max_momentum: Upper momentum boundary

362

- last_epoch: Index of last epoch

363

- verbose: Print message on every update

364

"""

365

366

class OneCycleLR(LRScheduler):

367

"""One cycle learning rate policy."""

368

def __init__(self, optimizer, max_lr, total_steps=None, epochs=None, steps_per_epoch=None,

369

pct_start=0.3, anneal_strategy='cos', cycle_momentum=True, base_momentum=0.85,

370

max_momentum=0.95, div_factor=25.0, final_div_factor=1e4, three_phase=False, last_epoch=-1, verbose=False):

371

"""

372

Parameters:

373

- optimizer: Wrapped optimizer

374

- max_lr: Upper learning rate boundary

375

- total_steps: Total number of steps in cycle

376

- epochs: Number of epochs (alternative to total_steps)

377

- steps_per_epoch: Steps per epoch (with epochs)

378

- pct_start: Fraction of the cycle (in number of steps) spent increasing learning rate (default: 0.3)

379

- anneal_strategy: 'cos' or 'linear' annealing strategy

380

- cycle_momentum: Cycle momentum inversely to learning rate

381

- base_momentum: Lower momentum boundary

382

- max_momentum: Upper momentum boundary

383

- div_factor: Determines initial learning rate (max_lr/div_factor)

384

- final_div_factor: Determines minimum learning rate (max_lr/(div_factor*final_div_factor))

385

- three_phase: Use three phase schedule

386

- last_epoch: Index of last epoch

387

- verbose: Print message on every update

388

"""

389

```

390

391

### Gradient Processing

392

393

Utilities for gradient manipulation and processing, provided by `torch.nn.utils`.

394

395

```python { .api }

396

def clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False):

397

"""

398

Clip gradient norm of parameters.

399

400

Parameters:

401

- parameters: Iterable of parameters or single tensor

402

- max_norm: Maximum norm of gradients

403

- norm_type: Type of norm (default: 2.0)

404

- error_if_nonfinite: Raise error if total norm is NaN or inf

405

406

Returns:

407

Total norm of the parameter gradients (viewed as a single vector)

408

"""

409

410

def clip_grad_value_(parameters, clip_value):

411

"""

412

Clip gradient values to specified range.

413

414

Parameters:

415

- parameters: Iterable of parameters or single tensor

416

- clip_value: Maximum absolute value for gradients

417

"""

418

```

419

420

### Stochastic Weight Averaging

421

422

Utilities for stochastic weight averaging to improve generalization, provided by `torch.optim.swa_utils`.

423

424

```python { .api }

425

class AveragedModel(nn.Module):

426

"""Averaged model for stochastic weight averaging."""

427

def __init__(self, model, device=None, avg_fn=None, use_buffers=False):

428

"""

429

Parameters:

430

- model: Model to average

431

- device: Device to store averaged parameters

432

- avg_fn: Function to compute running average

433

- use_buffers: Whether to average buffers

434

"""

435

def update_parameters(self, model): ...

436

437

class SWALR(LRScheduler):

438

"""Learning rate scheduler for stochastic weight averaging."""

439

def __init__(self, optimizer, swa_lr, anneal_epochs=10, anneal_strategy='cos', last_epoch=-1):

440

"""

441

Parameters:

442

- optimizer: Wrapped optimizer

443

- swa_lr: SWA learning rate

444

- anneal_epochs: Number of epochs for annealing (default: 10)

445

- anneal_strategy: 'cos' or 'linear' annealing strategy

446

- last_epoch: Index of last epoch

447

"""

448

```

449

450

## Usage Examples

451

452

### Basic Training Loop

453

454

```python

455

import torch

456

import torch.nn as nn

457

import torch.optim as optim

458

from torch.utils.data import DataLoader

459

460

# Setup model, loss, and optimizer

461

model = nn.Sequential(

462

nn.Linear(784, 128),

463

nn.ReLU(),

464

nn.Linear(128, 10)

465

)

466

criterion = nn.CrossEntropyLoss()

467

optimizer = optim.Adam(model.parameters(), lr=0.001)

468

469

# Training loop

470

def train_epoch(model, dataloader, criterion, optimizer):

471

model.train()

472

total_loss = 0

473

474

for batch_idx, (data, targets) in enumerate(dataloader):

475

# Zero gradients

476

optimizer.zero_grad()

477

478

# Forward pass

479

outputs = model(data)

480

loss = criterion(outputs, targets)

481

482

# Backward pass

483

loss.backward()

484

485

# Gradient clipping (optional)

486

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

487

488

# Update parameters

489

optimizer.step()

490

491

total_loss += loss.item()

492

493

return total_loss / len(dataloader)

494

495

# Example usage

496

# train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

497

# loss = train_epoch(model, train_loader, criterion, optimizer)

498

# print(f"Training loss: {loss:.4f}")

499

```

500

501

### Learning Rate Scheduling

502

503

```python

504

import torch

505

import torch.optim as optim

506

from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

507

508

# Setup optimizer and scheduler

509

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

510

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

511

512

# Alternative: Reduce on plateau

513

# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

514

515

# Training loop with scheduler

516

for epoch in range(100):

517

train_loss = train_epoch(model, train_loader, criterion, optimizer)

518

val_loss = validate(model, val_loader, criterion)

519

520

# Step scheduler

521

scheduler.step() # For StepLR

522

# scheduler.step(val_loss) # For ReduceLROnPlateau

523

524

current_lr = optimizer.param_groups[0]['lr']

525

print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}")

526

```

527

528

### Advanced Optimization with Multiple Parameter Groups

529

530

```python

531

import torch

532

import torch.optim as optim

533

534

# Different learning rates for different parts of the model

535

model = nn.Sequential(

536

nn.Linear(784, 128),

537

nn.ReLU(),

538

nn.Linear(128, 10)

539

)

540

541

# Create parameter groups

542

params = [

543

{'params': model[0].parameters(), 'lr': 0.001}, # First layer

544

{'params': model[2].parameters(), 'lr': 0.01} # Last layer

545

]

546

547

optimizer = optim.Adam(params, weight_decay=1e-4)

548

549

# Training with different learning rates

550

for epoch in range(100):

551

for batch_idx, (data, targets) in enumerate(train_loader):

552

optimizer.zero_grad()

553

outputs = model(data)

554

loss = criterion(outputs, targets)

555

loss.backward()

556

optimizer.step()

557

```

558

559

### Stochastic Weight Averaging

560

561

```python

562

import torch

563

import torch.optim as optim

564

from torch.optim.swa_utils import AveragedModel, SWALR

565

566

# Setup model and optimizer

567

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

568

optimizer = optim.SGD(model.parameters(), lr=0.1)

569

570

# Create averaged model and SWA scheduler

571

swa_model = AveragedModel(model)

572

swa_scheduler = SWALR(optimizer, swa_lr=0.05)

573

574

# Training with SWA

575

swa_start_epoch = 80

576

for epoch in range(100):

577

train_loss = train_epoch(model, train_loader, criterion, optimizer)

578

579

if epoch >= swa_start_epoch:

580

swa_model.update_parameters(model)

581

swa_scheduler.step()

582

else:

583

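# Regular scheduler before SWA (regular_scheduler is not created in this snippet; define one on the same optimizer first, e.g. a StepLR)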

584

regular_scheduler.step()

585

586

print(f"Epoch {epoch}: Loss: {train_loss:.4f}")

587

588

# Update SWA batch normalization statistics

589

torch.optim.swa_utils.update_bn(train_loader, swa_model)

590

591

# Use SWA model for inference

592

swa_model.eval()

593

```

594

595

### One Cycle Learning Rate Policy

596

597

```python

598

import torch

599

import torch.optim as optim

600

from torch.optim.lr_scheduler import OneCycleLR

601

602

# Setup optimizer

603

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

604

605

# One cycle scheduler

606

steps_per_epoch = len(train_loader)

607

scheduler = OneCycleLR(

608

optimizer,

609

max_lr=0.1,

610

epochs=100,

611

steps_per_epoch=steps_per_epoch,

612

pct_start=0.3,

613

div_factor=25,

614

final_div_factor=1e4

615

)

616

617

# Training loop

618

for epoch in range(100):

619

for batch_idx, (data, targets) in enumerate(train_loader):

620

optimizer.zero_grad()

621

outputs = model(data)

622

loss = criterion(outputs, targets)

623

loss.backward()

624

optimizer.step()

625

626

# Step after each batch

627

scheduler.step()

628

629

print(f"Epoch {epoch}: LR: {optimizer.param_groups[0]['lr']:.6f}")

630

```