
# Callbacks and Training Customization

Extensive callback system for customizing the training loop, including progress tracking, learning rate scheduling, regularization, logging, and advanced training techniques.

## Capabilities

### Core Callback Infrastructure

Base classes and essential callbacks that form the foundation of fastai's training system.

```python { .api }
class Callback:
    """
    Base class for training callbacks.
    Callbacks can hook into different points of the training loop.
    """

    def __init__(self): ...

    def before_fit(self):
        """Called before training starts."""

    def before_epoch(self):
        """Called before each epoch."""

    def before_train(self):
        """Called before the training phase of an epoch."""

    def before_batch(self):
        """Called before each batch."""

    def after_pred(self):
        """Called after model prediction."""

    def after_loss(self):
        """Called after loss computation."""

    def before_backward(self):
        """Called before the backward pass."""

    def after_backward(self):
        """Called after the backward pass."""

    def after_step(self):
        """Called after the optimizer step."""

    def after_cancel_batch(self):
        """Called if the batch is cancelled."""

    def after_batch(self):
        """Called after each batch."""

    def after_cancel_train(self):
        """Called if the training phase is cancelled."""

    def after_train(self):
        """Called after the training phase."""

    def before_validate(self):
        """Called before the validation phase."""

    def after_cancel_validate(self):
        """Called if validation is cancelled."""

    def after_validate(self):
        """Called after the validation phase."""

    def after_cancel_epoch(self):
        """Called if the epoch is cancelled."""

    def after_epoch(self):
        """Called after each epoch."""

    def after_cancel_fit(self):
        """Called if the fit is cancelled."""

    def after_fit(self):
        """Called after training completes."""

class TrainEvalCallback(Callback):
    """Handle switching between training and evaluation modes."""

    def before_fit(self): ...
    def before_train(self): ...
    def before_validate(self): ...

class Recorder(Callback):
    """Record training statistics and metrics."""

    def before_fit(self): ...
    def after_batch(self): ...
    def after_epoch(self): ...

    def plot_loss(self, skip_start=5, with_valid=True): ...
    def plot_sched(self, keys=None, figsize=None): ...
```
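
For example, a custom callback subclasses `Callback` and overrides only the events it needs. The sketch below is illustrative: it assumes a fastai-style `Learner` that exposes the running loss to callbacks as `self.loss` and accepts callbacks through a `cbs` argument; `dls`, `model`, and `loss_func` are placeholders for objects built elsewhere.

```python
# Minimal sketch of a custom callback. Assumes a fastai-style Learner that
# exposes the current loss as `self.loss` inside callback events.
class LogLossCallback(Callback):
    """Collect the loss of every training batch."""
    def before_fit(self):
        self.batch_losses = []                       # reset storage when training starts

    def after_batch(self):
        self.batch_losses.append(float(self.loss))   # loss is set by the training loop

# Hypothetical wiring (names are placeholders, not part of this spec):
# learn = Learner(dls, model, loss_func=loss_func, cbs=[LogLossCallback()])
# learn.fit(1)
```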

### Learning Rate Scheduling

Callbacks for sophisticated learning rate scheduling and optimization.

```python { .api }
class OneCycleTraining(Callback):
    """
    One-cycle learning rate policy for super-convergence.
    Cycles the learning rate from low to high and back to low.
    """

    def __init__(self, max_lr=None, div_factor=25.0, final_div=None,
                 pct_start=0.25, anneal_strategy='cos', cycle_momentum=True,
                 base_momentum=0.85, max_momentum=0.95, wd=None,
                 moms=None, **kwargs):
        """
        Initialize one-cycle training.

        Parameters:
        - max_lr: Maximum learning rate
        - div_factor: Initial LR divisor (max_lr/div_factor)
        - final_div: Final LR divisor
        - pct_start: Fraction of the cycle used for warm-up
        - anneal_strategy: 'cos' or 'linear' annealing
        - cycle_momentum: Cycle momentum inversely to the learning rate
        - base_momentum: Minimum momentum value
        - max_momentum: Maximum momentum value
        - wd: Weight decay
        - moms: Custom momentum schedule
        """

class ReduceLROnPlateau(Callback):
    """Reduce the learning rate when the monitored metric stops improving."""

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0,
                 patience=1, factor=0.2, min_lr=0, reset_on_fit=True):
        """
        Initialize learning rate reduction on plateau.

        Parameters:
        - monitor: Metric to monitor
        - comp: Comparison function (np.less for loss, np.greater for accuracy)
        - min_delta: Minimum change to qualify as an improvement
        - patience: Epochs to wait before reducing
        - factor: Factor to reduce the LR by
        - min_lr: Minimum learning rate
        - reset_on_fit: Reset the patience counter on a new fit
        """

class LRFinder(Callback):
    """Learning rate finder for optimal LR discovery."""

    def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, step_mode='exp'): ...
```
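
As a usage sketch, these schedulers are typically attached for a single training run rather than at `Learner` construction. The snippet below assumes a fastai-style `learn.fit(..., cbs=[...])` entry point; `learn` is a placeholder for a Learner built elsewhere.

```python
# Sketch: per-fit learning rate scheduling (assumes a fastai-style Learner
# named `learn`; the callback classes are the ones specified above).
one_cycle = OneCycleTraining(max_lr=1e-3, pct_start=0.25, anneal_strategy='cos')
plateau = ReduceLROnPlateau(monitor='valid_loss', patience=2, factor=0.2, min_lr=1e-6)

# learn.fit(10, cbs=[one_cycle])   # super-convergence schedule
# learn.fit(10, cbs=[plateau])     # or: decay the LR when valid_loss plateaus
# learn.fit(1, cbs=[LRFinder()])   # LR range test before committing to a schedule
# learn.recorder.plot_sched()      # inspect the resulting LR/momentum schedule
```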

### Training Enhancement Callbacks

Callbacks that enhance training stability and performance.

```python { .api }
class MixedPrecision(Callback):
    """
    Automatic mixed precision training for faster training with lower memory usage.
    Uses float16 for forward pass and float32 for gradients.
    """

    def __init__(self, loss_scale=512, flat_master=False, dynamic=True,
                 clip=None, eps=1e-5, scale_wait=500): ...

class GradientClip(Callback):
    """Gradient clipping for training stability."""

    def __init__(self, max_norm=1.0, norm_type=2.0): ...

class GradientAccumulation(Callback):
    """Accumulate gradients over multiple batches before optimizer step."""

    def __init__(self, n_acc=32): ...

class BnFreeze(Callback):
    """Freeze batch normalization layers during training."""

    def before_epoch(self): ...
```
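
Combining these callbacks is a common way to fit larger models on limited GPU memory; with gradient accumulation the effective batch size becomes `n_acc` times the DataLoader batch size. A minimal sketch, assuming a fastai-style `Learner` named `learn` (a placeholder):

```python
# Sketch: memory-friendly training setup built from the callbacks above
# (assumes a fastai-style Learner named `learn`).
cbs = [
    MixedPrecision(dynamic=True),      # fp16 forward pass with dynamic loss scaling
    GradientAccumulation(n_acc=32),    # step the optimizer once every 32 batches
    GradientClip(max_norm=1.0),        # clip the gradient norm for stability
    BnFreeze(),                        # keep BatchNorm layers frozen
]
# learn.fit(5, cbs=cbs)
```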

### Monitoring and Logging

Callbacks for tracking training progress and logging to external services.

```python { .api }
class ProgressCallback(Callback):
    """Display training progress with progress bars."""

    def __init__(self, plot=False, display=True): ...

    def before_fit(self): ...
    def after_batch(self): ...
    def after_epoch(self): ...

class CSVLogger(Callback):
    """Log training metrics to CSV file."""

    def __init__(self, fname='history.csv', append=False): ...

    def after_epoch(self): ...

class TensorBoardCallback(Callback):
    """Log metrics and model graph to TensorBoard."""

    def __init__(self, log_dir=None, trace_model=True, log_preds=True,
                 n_preds=9, projector=False): ...

    def before_fit(self): ...
    def after_epoch(self): ...
    def after_fit(self): ...

class WandbCallback(Callback):
    """Integration with Weights & Biases experiment tracking."""

    def __init__(self, log_preds=True, log_model=True, log_dataset=False,
                 dataset_name=None, valid_idx=1, n_preds=36, seed=12345): ...

    def before_fit(self): ...
    def after_epoch(self): ...
    def after_fit(self): ...

class CometCallback(Callback):
    """Integration with Comet.ml experiment tracking."""

    def __init__(self, log_model=True, log_dataset=False, project_name=None,
                 log_code=True, log_preds=True, n_preds=9): ...

    def before_fit(self): ...
    def after_epoch(self): ...
```
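
Loggers are usually stacked, since each writes to a different sink. A minimal sketch, assuming a fastai-style `Learner` named `learn` (a placeholder); the hosted services need their own setup and credentials:

```python
# Sketch: log every epoch to the console, a CSV file, and TensorBoard
# (assumes a fastai-style Learner named `learn`).
cbs = [
    ProgressCallback(),                        # progress bars per batch/epoch
    CSVLogger(fname='history.csv'),            # one row of metrics per epoch
    TensorBoardCallback(log_dir='runs/exp1'),  # scalars and the model graph
    # WandbCallback(log_model=True),           # requires a W&B run to be initialized first
]
# learn.fit(3, cbs=cbs)
```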

### Model Management Callbacks

Callbacks for saving, loading, and managing model checkpoints.

```python { .api }
class SaveModelCallback(Callback):
    """Save model checkpoints during training."""

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0,
                 fname='bestmodel', every_epoch=False, at_end=False,
                 with_opt=False, reset_on_fit=True):
        """
        Initialize model saving callback.

        Parameters:
        - monitor: Metric to monitor for best model
        - comp: Comparison function (np.less for loss)
        - min_delta: Minimum improvement required
        - fname: Filename for saved model
        - every_epoch: Save every epoch
        - at_end: Save at end of training
        - with_opt: Include optimizer state
        - reset_on_fit: Reset best metric on new fit
        """

class EarlyStoppingCallback(Callback):
    """Stop training early when metric stops improving."""

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0,
                 patience=1, restore_best_weights=True, reset_on_fit=True):
        """
        Initialize early stopping.

        Parameters:
        - monitor: Metric to monitor
        - comp: Comparison function
        - min_delta: Minimum improvement
        - patience: Epochs to wait
        - restore_best_weights: Restore best weights when stopping
        - reset_on_fit: Reset counter on new fit
        """
```
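
Because both callbacks share the `monitor`/`comp`/`min_delta` convention, they are commonly used together: keep the best checkpoint and stop once the monitored metric stalls. A minimal sketch, assuming a fastai-style `Learner` named `learn` (a placeholder):

```python
import numpy as np

# Sketch: checkpoint the best model and stop early when valid_loss stalls
# (assumes a fastai-style Learner named `learn`).
cbs = [
    SaveModelCallback(monitor='valid_loss', comp=np.less, fname='bestmodel'),
    EarlyStoppingCallback(monitor='valid_loss', min_delta=0.01, patience=3),
]
# learn.fit(20, cbs=cbs)   # stops early; best weights are saved as 'bestmodel'
```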

### Regularization and Augmentation

Callbacks implementing regularization techniques and data augmentation.

```python { .api }
class MixUp(Callback):
    """
    MixUp data augmentation during training.
    Combines pairs of examples and their labels.
    """

    def __init__(self, alpha=0.4, stack_x=False, stack_y=True): ...

    def before_batch(self): ...

class CutMix(Callback):
    """
    CutMix augmentation combining spatial mixing with MixUp.
    Cuts and pastes patches between training images.
    """

    def __init__(self, alpha=1.0): ...

    def before_batch(self): ...

class RNNRegularizer(Callback):
    """Regularization techniques specific to RNN models."""

    def __init__(self, alpha=2, beta=1, **kwargs): ...

class ChannelsLast(Callback):
    """Memory layout optimization for CNNs."""

    def before_fit(self): ...
    def before_batch(self): ...
```
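
MixUp and CutMix both rewrite the incoming batch in `before_batch`, so normally only one of them is enabled per run. A minimal sketch, assuming a fastai-style `Learner` named `learn` (a placeholder) trained on images:

```python
# Sketch: batch-level augmentation plus a memory-layout optimization
# (assumes a fastai-style Learner named `learn`).
mixup = MixUp(alpha=0.4)    # mixing coefficient drawn from Beta(alpha, alpha)
cutmix = CutMix(alpha=1.0)  # patch-based mixing; use instead of MixUp, not with it

# learn.fit(10, cbs=[mixup, ChannelsLast()])   # channels-last layout can speed up CNNs
```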

### Advanced Training Techniques

Callbacks implementing advanced training strategies and techniques.

```python { .api }
class LabelSmoothingCrossEntropy(Callback):
    """Label smoothing regularization technique."""

    def __init__(self, eps=0.1, reduction='mean'): ...

class SelfDistillation(Callback):
    """Self-distillation training technique."""

    def __init__(self, temperature=3.0, alpha=0.7): ...

class Lookahead(Callback):
    """Lookahead optimizer wrapper."""

    def __init__(self, k=5, alpha=0.5): ...

class FreezeCallback(Callback):
    """Freeze/unfreeze model layers during training."""

    def __init__(self, freeze_epochs=1): ...

    def before_epoch(self): ...

class ShowGraphCallback(Callback):
    """Visualize model architecture and training graphs."""

    def after_fit(self): ...
```

### Custom Callback Utilities

Utilities for creating and managing custom callbacks.

```python { .api }
def callback_handler(cbs=None, **kwargs):
    """Create callback handler with list of callbacks."""

class CallbackHandler:
    """Handler that manages and calls multiple callbacks."""

    def __init__(self, cbs=None): ...

    def add_cb(self, cb): ...
    def remove_cb(self, cb): ...
    def __call__(self, event_name): ...

class CancelFitException(Exception):
    """Exception to cancel training."""

class CancelEpochException(Exception):
    """Exception to cancel current epoch."""

class CancelTrainException(Exception):
    """Exception to cancel training phase."""

class CancelValidException(Exception):
    """Exception to cancel validation phase."""

class CancelBatchException(Exception):
    """Exception to cancel current batch."""
```
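
The cancellation exceptions are how callbacks interrupt the loop: raising one inside an event skips the rest of the corresponding batch, epoch, or fit, after which the matching `after_cancel_*` event fires. A minimal sketch, assuming the training loop exposes the current loss to callbacks as `self.loss`:

```python
# Sketch: end the whole fit once the training loss drops below a threshold.
# Assumes a fastai-style loop that exposes `self.loss` to callbacks and
# catches CancelFitException to terminate training cleanly.
class StopAtLoss(Callback):
    def __init__(self, threshold=0.1):
        self.threshold = threshold

    def after_batch(self):
        if float(self.loss) < self.threshold:
            raise CancelFitException()   # fires after_cancel_fit, then after_fit
```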

### Training Control and Debugging Callbacks

Advanced callbacks for training control, debugging, and model analysis.

```python { .api }
class TerminateOnNaNCallback(Callback):
    """
    Automatically terminate training if loss becomes NaN or infinite.
    Essential for robust training pipelines.
    """
    order = -9

    def after_batch(self):
        """Test if loss is NaN/inf and interrupt training."""

class ShortEpochCallback(Callback):
    """
    Fit only a percentage of an epoch for debugging/testing.

    Parameters:
    - pct: Percentage of epoch to train (0.01 = 1%)
    - short_valid: Whether to also shorten validation
    """
    def __init__(self, pct=0.01, short_valid=True): ...

class CollectDataCallback(Callback):
    """
    Collect all batches with predictions and losses for debugging.
    Useful for analyzing model behavior and debugging issues.
    """
    def before_fit(self): ...
    def after_batch(self): ...
```
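
These callbacks are handy for smoke-testing a pipeline: keep the run short, abort on numerical blow-ups, and capture batches for later inspection. A minimal sketch, assuming a fastai-style `Learner` named `learn` (a placeholder):

```python
# Sketch: a fast, fail-safe debugging run using the callbacks above
# (assumes a fastai-style Learner named `learn`).
cbs = [
    ShortEpochCallback(pct=0.01, short_valid=True),  # train on ~1% of each epoch
    TerminateOnNaNCallback(),                        # stop if the loss becomes NaN/inf
    CollectDataCallback(),                           # keep batches, preds, and losses
]
# learn.fit(1, cbs=cbs)
```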

### Model Analysis and Hook Callbacks

Callbacks for analyzing model internals and registering hooks on model layers.

```python { .api }
class ActivationStats(HookCallback):
    """
    Record activation statistics (mean, std, near-zero percentage) during training.
    Essential for debugging vanishing/exploding gradients and dead neurons.

    Parameters:
    - with_hist: Whether to record activation histograms
    """
    order = -20

    def __init__(self, with_hist=False, **kwargs): ...
    def layer_stats(self, idx): ...
    def hist(self, idx): ...
    def color_dim(self, idx, figsize=(10,5)): ...
    def plot_layer_stats(self, idx): ...

class HookCallback(Callback):
    """
    Base callback for registering hooks on model modules.
    Foundation for advanced model introspection and analysis.

    Parameters:
    - modules: Specific modules to hook (None = all with params)
    - every: Register hooks every N training iterations
    - remove_end: Remove hooks after training
    - is_forward: Forward vs backward hooks
    - detach: Detach tensors from computation graph
    - cpu: Move hooked data to CPU
    - include_paramless: Include modules without parameters
    """
    def __init__(self, modules=None, every=None, remove_end=True,
                 is_forward=True, detach=True, cpu=True,
                 include_paramless=False): ...
```
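
`ActivationStats` is typically attached for a short diagnostic run and then queried per layer. A minimal sketch, assuming a fastai-style `Learner` named `learn` (a placeholder) and that the callback object keeps its recorded statistics after fitting:

```python
# Sketch: record activation statistics during a short run, then inspect
# the first hooked layer (assumes a fastai-style Learner named `learn`).
stats = ActivationStats(with_hist=True)
# learn.fit(1, cbs=[stats])
# stats.plot_layer_stats(0)   # mean / std / near-zero fraction over training
# stats.color_dim(0)          # histogram-over-time ("colorful dimension") plot
```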

### RNN-Specific Callbacks

Specialized callbacks for training recurrent neural networks and sequence models.

```python { .api }
class ModelResetter(Callback):
    """
    Reset RNN hidden states between training/validation phases.
    Essential for proper RNN training with stateful hidden states.
    """
    def before_train(self): ...
    def before_validate(self): ...
    def after_fit(self): ...

class RNNCallback(Callback):
    """
    Handle RNN outputs and save raw/dropout outputs for regularization.
    Manages the complexities of RNN training loops.
    """
    def after_pred(self): ...
```

### Advanced Prediction and Uncertainty Callbacks

Callbacks for enhanced prediction gathering and uncertainty estimation.

```python { .api }
class MCDropoutCallback(Callback):
    """
    Enable Monte Carlo Dropout for uncertainty estimation.
    Keeps dropout layers active during validation for probabilistic predictions.
    """
    def before_validate(self): ...
    def after_validate(self): ...

class FetchPredsCallback(Callback):
    """
    Fetch predictions during training loop with callback management.

    Parameters:
    - ds_idx: Dataset index (0=train, 1=valid)
    - dl: Custom DataLoader for predictions
    - with_input: Return inputs alongside predictions
    - with_decoded: Return decoded predictions
    - cbs: Callbacks to temporarily remove
    - reorder: Sort prediction results
    """
    def __init__(self, ds_idx=1, dl=None, with_input=False,
                 with_decoded=False, cbs=None, reorder=True): ...
```
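
Monte Carlo Dropout estimates uncertainty by running several stochastic forward passes over the same data and looking at the spread of the predictions. The sketch below assumes a fastai-style `Learner` named `learn` whose `get_preds` accepts per-call callbacks; both the name and that signature are assumptions of this sketch, not part of the spec above.

```python
import torch

# Sketch: predictive mean and uncertainty via MC Dropout (assumes a
# fastai-style Learner `learn` with a get_preds(cbs=...) entry point).
def mc_dropout_preds(learn, n_samples=10):
    runs = [learn.get_preds(cbs=[MCDropoutCallback()])[0] for _ in range(n_samples)]
    stacked = torch.stack(runs)                        # (n_samples, n_items, n_classes)
    return stacked.mean(dim=0), stacked.std(dim=0)     # mean prediction, per-class spread
```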

### Advanced Mixed Precision Training

Enhanced mixed precision training with fine-grained control over scaling and gradients.

```python { .api }
class NonNativeMixedPrecision(Callback):
    """
    Manual mixed precision implementation for advanced control.
    Provides more flexibility than PyTorch's native automatic mixed precision.

    Parameters:
    - loss_scale: Loss scaling factor for gradient stability
    - flat_master: Flatten fp32 parameters for performance
    - dynamic: Automatic loss scale adjustment
    - max_loss_scale: Maximum loss scale value
    - div_factor: Scale adjustment factor
    - scale_wait: Batches to wait before scale increase
    - clip: Gradient clipping value
    """
    order = 10

    def __init__(self, loss_scale=512, flat_master=False, dynamic=True,
                 max_loss_scale=2.**24, div_factor=2., scale_wait=500, clip=None): ...
```

### Integration and Production Callbacks

Callbacks for integration with external platforms and production workflows.

```python { .api }
class AzureMLCallback(Callback):
    """
    Integration with Azure Machine Learning for experiment tracking.
    Automatically logs metrics, parameters, and models to Azure ML.

    Parameters:
    - learn: Learner instance
    - log_model: Whether to log the trained model
    - model_name: Name for the logged model
    """
    def __init__(self, learn=None, log_model=False, model_name='model'): ...

class CaptumInterpretation:
    """
    Model interpretability using Facebook's Captum library.
    Provides advanced attribution and visualization methods.

    Parameters:
    - learn: Learner instance
    - cmap_name: Colormap name for visualizations
    - methods: Visualization methods
    - signs: Attribution signs to display
    """
    def __init__(self, learn, cmap_name='custom blue', colors=None, N=256,
                 methods=('original_image', 'heat_map'), signs=("all", "positive")): ...
    def visualize(self, inp, metric='IG', n_steps=1000, baseline_type='zeros'): ...
    def insights(self, inp_data, debug=True): ...
```