# Callbacks and Training Customization

An extensive callback system for customizing the training loop, including progress tracking, learning rate scheduling, regularization, logging, and advanced training techniques.

## Capabilities

### Core Callback Infrastructure

Base classes and essential callbacks that form the foundation of fastai's training system.

```python { .api }
class Callback:
    """
    Base class for training callbacks.
    Callbacks can hook into different points of the training loop.
    """

    def __init__(self): ...

    def before_fit(self):
        """Called before training starts."""

    def before_epoch(self):
        """Called before each epoch."""

    def before_train(self):
        """Called before the training phase of an epoch."""

    def before_batch(self):
        """Called before each batch."""

    def after_pred(self):
        """Called after the model's forward pass."""

    def after_loss(self):
        """Called after the loss is computed."""

    def before_backward(self):
        """Called before the backward pass."""

    def after_backward(self):
        """Called after the backward pass."""

    def after_step(self):
        """Called after the optimizer step."""

    def after_cancel_batch(self):
        """Called if the current batch is cancelled."""

    def after_batch(self):
        """Called after each batch."""

    def after_cancel_train(self):
        """Called if the training phase is cancelled."""

    def after_train(self):
        """Called after the training phase."""

    def before_validate(self):
        """Called before the validation phase."""

    def after_cancel_validate(self):
        """Called if the validation phase is cancelled."""

    def after_validate(self):
        """Called after the validation phase."""

    def after_cancel_epoch(self):
        """Called if the current epoch is cancelled."""

    def after_epoch(self):
        """Called after each epoch."""

    def after_cancel_fit(self):
        """Called if the whole fit is cancelled."""

    def after_fit(self):
        """Called after training completes."""

class TrainEvalCallback(Callback):
    """Handle switching between training and evaluation modes."""

    def before_fit(self): ...
    def before_train(self): ...
    def before_validate(self): ...

class Recorder(Callback):
    """Record training statistics and metrics."""

    def before_fit(self): ...
    def after_batch(self): ...
    def after_epoch(self): ...

    def plot_loss(self, skip_start=5, with_valid=True): ...
    def plot_sched(self, keys=None, figsize=None): ...
```
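
Custom callbacks subclass `Callback` and override only the events they need. A minimal sketch (the `LossSpikeLogger` name and the doubling threshold are hypothetical; a fastai-style `Learner` named `learn` is assumed):

```python
class LossSpikeLogger(Callback):
    "Warn when the training loss more than doubles between consecutive batches."
    def before_fit(self):
        self.prev_loss = None

    def after_batch(self):
        if not self.training: return  # only watch the training phase
        loss = self.loss.item()       # callbacks can read Learner state directly
        if self.prev_loss is not None and loss > 2 * self.prev_loss:
            print(f"Loss spike at iter {self.iter}: {self.prev_loss:.4f} -> {loss:.4f}")
        self.prev_loss = loss

learn.fit(3, cbs=LossSpikeLogger())
learn.recorder.plot_loss()  # Recorder is attached to every Learner by default
```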

### Learning Rate Scheduling

Callbacks for sophisticated learning rate scheduling and optimization.

```python { .api }
class OneCycleTraining(Callback):
    """
    One-cycle learning rate policy for super-convergence.
    Cycles the learning rate from low to high and back to low.
    """

    def __init__(self, max_lr=None, div_factor=25.0, final_div=None,
                 pct_start=0.25, anneal_strategy='cos', cycle_momentum=True,
                 base_momentum=0.85, max_momentum=0.95, wd=None,
                 moms=None, **kwargs):
        """
        Initialize one-cycle training.

        Parameters:
        - max_lr: Maximum learning rate
        - div_factor: Initial LR divisor (initial LR = max_lr / div_factor)
        - final_div: Final LR divisor
        - pct_start: Fraction of the cycle spent on warmup
        - anneal_strategy: 'cos' or 'linear' annealing
        - cycle_momentum: Cycle momentum inversely to the LR
        - base_momentum: Minimum momentum value
        - max_momentum: Maximum momentum value
        - wd: Weight decay
        - moms: Custom momentum schedule
        """

class ReduceLROnPlateau(Callback):
    """Reduce the learning rate when a monitored metric stops improving."""

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0,
                 patience=1, factor=0.2, min_lr=0, reset_on_fit=True):
        """
        Initialize learning rate reduction on plateau.

        Parameters:
        - monitor: Metric to monitor
        - comp: Comparison function (np.less for loss, np.greater for accuracy)
        - min_delta: Minimum change that qualifies as an improvement
        - patience: Epochs to wait before reducing
        - factor: Factor by which to reduce the LR
        - min_lr: Minimum learning rate
        - reset_on_fit: Reset the patience counter on each new fit
        """

class LRFinder(Callback):
    """Learning rate finder for discovering a good starting LR."""

    def __init__(self, start_lr=1e-7, end_lr=10, num_it=100, step_mode='exp'): ...
```
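
In practice these schedulers are either invoked through convenience methods or passed as callbacks at fit time. A hedged sketch of typical usage, assuming a fastai-style `Learner` named `learn`:

```python
learn.lr_find()  # runs the LR finder; inspect the plot for a good starting LR

# One-cycle policy via the usual convenience method
learn.fit_one_cycle(10, lr_max=1e-3)

# Plateau-based decay as a plain callback
learn.fit(20, lr=1e-3,
          cbs=ReduceLROnPlateau(monitor='valid_loss', patience=2, factor=0.2))
```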

### Training Enhancement Callbacks

Callbacks that enhance training stability and performance.

```python { .api }
class MixedPrecision(Callback):
    """
    Automatic mixed precision training for faster training with lower memory usage.
    Computes the forward and backward passes in float16 where safe, keeping
    float32 weights and using loss scaling to avoid gradient underflow.
    """

    def __init__(self, loss_scale=512, flat_master=False, dynamic=True,
                 clip=None, eps=1e-5, scale_wait=500): ...

class GradientClip(Callback):
    """Gradient clipping for training stability."""

    def __init__(self, max_norm=1.0, norm_type=2.0): ...

class GradientAccumulation(Callback):
    """Accumulate gradients over multiple batches before each optimizer step."""

    def __init__(self, n_acc=32): ...

class BnFreeze(Callback):
    """Keep batch normalization layers in eval mode during training."""

    def before_epoch(self): ...
```
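
These callbacks compose: passing several in `cbs` applies them all in the same fit. A sketch (assumes a CUDA-capable setup and a `Learner` named `learn`):

```python
cbs = [
    MixedPrecision(),               # fp16 compute with loss scaling
    GradientClip(max_norm=1.0),     # clip the gradient norm before each step
    GradientAccumulation(n_acc=8),  # step the optimizer only every 8 batches
]
learn.fit(5, cbs=cbs)
```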

### Monitoring and Logging

Callbacks for tracking training progress and logging to external services.

```python { .api }
class ProgressCallback(Callback):
    """Display training progress with progress bars."""

    def __init__(self, plot=False, display=True): ...

    def before_fit(self): ...
    def after_batch(self): ...
    def after_epoch(self): ...

class CSVLogger(Callback):
    """Log training metrics to CSV file."""

    def __init__(self, fname='history.csv', append=False): ...

    def after_epoch(self): ...

class TensorBoardCallback(Callback):
    """Log metrics and model graph to TensorBoard."""

    def __init__(self, log_dir=None, trace_model=True, log_preds=True,
                 n_preds=9, projector=False): ...

    def before_fit(self): ...
    def after_epoch(self): ...
    def after_fit(self): ...

class WandbCallback(Callback):
    """Integration with Weights & Biases experiment tracking."""

    def __init__(self, log_preds=True, log_model=True, log_dataset=False,
                 dataset_name=None, valid_idx=1, n_preds=36, seed=12345): ...

    def before_fit(self): ...
    def after_epoch(self): ...
    def after_fit(self): ...

class CometCallback(Callback):
    """Integration with Comet.ml experiment tracking."""

    def __init__(self, log_model=True, log_dataset=False, project_name=None,
                 log_code=True, log_preds=True, n_preds=9): ...

    def before_fit(self): ...
    def after_epoch(self): ...
```
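
A sketch of attaching loggers at fit time (assumes a `Learner` named `learn`; the W&B lines additionally assume the `wandb` package is installed and you are logged in):

```python
# Local CSV log of per-epoch metrics
learn.fit(5, cbs=CSVLogger(fname='history.csv'))

# Weights & Biases: initialize a run first, then attach the callback
import wandb
wandb.init(project='my-project')
learn.fit(5, cbs=WandbCallback(log_preds=True))
wandb.finish()
```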

### Model Management Callbacks

Callbacks for saving, loading, and managing model checkpoints.

```python { .api }
class SaveModelCallback(Callback):
    """Save model checkpoints during training."""

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0,
                 fname='bestmodel', every_epoch=False, at_end=False,
                 with_opt=False, reset_on_fit=True):
        """
        Initialize model saving callback.

        Parameters:
        - monitor: Metric to monitor for best model
        - comp: Comparison function (np.less for loss)
        - min_delta: Minimum improvement required
        - fname: Filename for saved model
        - every_epoch: Save every epoch
        - at_end: Save at end of training
        - with_opt: Include optimizer state
        - reset_on_fit: Reset best metric on new fit
        """

class EarlyStoppingCallback(Callback):
    """Stop training early when metric stops improving."""

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0,
                 patience=1, restore_best_weights=True, reset_on_fit=True):
        """
        Initialize early stopping.

        Parameters:
        - monitor: Metric to monitor
        - comp: Comparison function
        - min_delta: Minimum improvement
        - patience: Epochs to wait
        - restore_best_weights: Restore best weights when stopping
        - reset_on_fit: Reset counter on new fit
        """
```
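
A common pattern combines both callbacks so that long runs keep the best checkpoint and stop once progress stalls. A sketch, assuming a `Learner` named `learn`:

```python
cbs = [
    SaveModelCallback(monitor='valid_loss', fname='bestmodel'),
    EarlyStoppingCallback(monitor='valid_loss', min_delta=0.01, patience=3),
]
learn.fit(50, cbs=cbs)
learn.load('bestmodel')  # reload the best checkpoint after training
```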

### Regularization and Augmentation

Callbacks implementing regularization techniques and data augmentation.

```python { .api }
class MixUp(Callback):
    """
    MixUp data augmentation during training.
    Combines pairs of examples and their labels by convex interpolation.
    """

    def __init__(self, alpha=0.4, stack_x=False, stack_y=True): ...

    def before_batch(self): ...

class CutMix(Callback):
    """
    CutMix augmentation combining spatial mixing with MixUp-style label mixing.
    Cuts and pastes patches between training images.
    """

    def __init__(self, alpha=1.0): ...

    def before_batch(self): ...

class RNNRegularizer(Callback):
    """Regularization techniques specific to RNN models (activation and temporal activation regularization)."""

    def __init__(self, alpha=2, beta=1, **kwargs): ...

class ChannelsLast(Callback):
    """Use channels-last memory layout for faster convolutions on modern GPUs."""

    def before_fit(self): ...
    def before_batch(self): ...
```
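
Because these run in `before_batch`, they are passed at fit time rather than baked into the data pipeline. A sketch, assuming a vision `Learner` named `learn`:

```python
# Interpolate pairs of images and their labels
learn.fit_one_cycle(10, cbs=MixUp(alpha=0.4))

# Or paste patches between images instead
learn.fit_one_cycle(10, cbs=CutMix(alpha=1.0))
```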

### Advanced Training Techniques

Callbacks implementing advanced training strategies.

```python { .api }
class LabelSmoothingCrossEntropy(Callback):
    """Label smoothing regularization technique."""

    def __init__(self, eps=0.1, reduction='mean'): ...

class SelfDistillation(Callback):
    """Self-distillation training technique."""

    def __init__(self, temperature=3.0, alpha=0.7): ...

class Lookahead(Callback):
    """Lookahead optimizer wrapper."""

    def __init__(self, k=5, alpha=0.5): ...

class FreezeCallback(Callback):
    """Freeze/unfreeze model layers during training."""

    def __init__(self, freeze_epochs=1): ...

    def before_epoch(self): ...

class ShowGraphCallback(Callback):
    """Plot training and validation losses as a live-updating graph during training."""

    def after_fit(self): ...
```
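
A brief sketch of attaching one of these at fit time (assumes a `Learner` named `learn`):

```python
# Live-updating loss graph while training
learn.fit_one_cycle(5, cbs=ShowGraphCallback())
```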

### Custom Callback Utilities

Utilities for creating and managing custom callbacks.

```python { .api }
def callback_handler(cbs=None, **kwargs):
    """Create a callback handler from a list of callbacks."""

class CallbackHandler:
    """Handler that manages and calls multiple callbacks."""

    def __init__(self, cbs=None): ...

    def add_cb(self, cb): ...
    def remove_cb(self, cb): ...
    def __call__(self, event_name): ...

class CancelFitException(Exception):
    """Raise to cancel the whole fit."""

class CancelEpochException(Exception):
    """Raise to cancel the current epoch."""

class CancelTrainException(Exception):
    """Raise to cancel the training phase of the current epoch."""

class CancelValidException(Exception):
    """Raise to cancel the validation phase of the current epoch."""

class CancelBatchException(Exception):
    """Raise to cancel the current batch."""
```
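
The cancellation exceptions are how callbacks steer the loop: raising one inside an event skips ahead to the matching `after_cancel_*` handler. A sketch with a hypothetical helper (assumes a `Learner` named `learn`):

```python
class StopAfterNBatches(Callback):
    "Hypothetical helper: end the whole fit after a fixed number of training batches."
    def __init__(self, n=100): self.n = n

    def after_batch(self):
        if self.train_iter >= self.n:
            raise CancelFitException()  # triggers after_cancel_fit, then after_fit

learn.fit(10, cbs=StopAfterNBatches(n=50))
```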

### Training Control and Debugging Callbacks

Advanced callbacks for training control, debugging, and model analysis.

```python { .api }
class TerminateOnNaNCallback(Callback):
    """
    Automatically terminate training if the loss becomes NaN or infinite.
    Essential for robust training pipelines.
    """
    order = -9

    def after_batch(self):
        """Test whether the loss is NaN/inf and interrupt training if so."""

class ShortEpochCallback(Callback):
    """
    Fit only a percentage of an epoch, for debugging/testing.

    Parameters:
    - pct: Percentage of the epoch to train (0.01 = 1%)
    - short_valid: Whether to also shorten validation
    """
    def __init__(self, pct=0.01, short_valid=True): ...

class CollectDataCallback(Callback):
    """
    Collect all batches, along with predictions and losses, for debugging.
    Useful for analyzing model behavior and debugging issues.
    """
    def before_fit(self): ...
    def after_batch(self): ...
```
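
These are handy for smoke-testing a pipeline before committing to a long run. A sketch, assuming a `Learner` named `learn`:

```python
cbs = [
    ShortEpochCallback(pct=0.01),  # run roughly 1% of each epoch
    TerminateOnNaNCallback(),      # abort immediately if the loss diverges
]
learn.fit(1, cbs=cbs)
```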

### Model Analysis and Hook Callbacks

Callbacks for analyzing model internals and registering hooks on model layers.

```python { .api }
class HookCallback(Callback):
    """
    Base callback for registering hooks on model modules.
    Foundation for advanced model introspection and analysis.

    Parameters:
    - modules: Specific modules to hook (None = all modules with parameters)
    - every: Register hooks every N training iterations
    - remove_end: Remove hooks after training
    - is_forward: Use forward hooks rather than backward hooks
    - detach: Detach tensors from the computation graph
    - cpu: Move hooked data to the CPU
    - include_paramless: Include modules without parameters
    """
    def __init__(self, modules=None, every=None, remove_end=True,
                 is_forward=True, detach=True, cpu=True,
                 include_paramless=False): ...

class ActivationStats(HookCallback):
    """
    Record activation statistics (mean, std, near-zero percentage) during training.
    Essential for debugging vanishing/exploding activations and dead neurons.

    Parameters:
    - with_hist: Whether to record activation histograms
    """
    order = -20

    def __init__(self, with_hist=False, **kwargs): ...
    def layer_stats(self, idx): ...
    def hist(self, idx): ...
    def color_dim(self, idx, figsize=(10,5)): ...
    def plot_layer_stats(self, idx): ...
```
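
A sketch of inspecting activations after a short fit (assumes a `Learner` named `learn`; the layer index `-2` is illustrative):

```python
stats = ActivationStats(with_hist=True)
learn.fit(1, cbs=stats)

stats.color_dim(-2)         # "colorful dimension" histogram plot for one layer
stats.plot_layer_stats(-2)  # mean, std, and near-zero fraction over training
```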

### RNN-Specific Callbacks

Specialized callbacks for training recurrent neural networks and sequence models.

```python { .api }
class ModelResetter(Callback):
    """
    Reset RNN hidden states between training/validation phases.
    Essential for proper RNN training with stateful hidden states.
    """
    def before_train(self): ...
    def before_validate(self): ...
    def after_fit(self): ...

class RNNCallback(Callback):
    """
    Handle RNN outputs, saving the raw and dropped-out outputs for regularization.
    Manages the complexities of RNN training loops.
    """
    def after_pred(self): ...
```

### Advanced Prediction and Uncertainty Callbacks

Callbacks for enhanced prediction gathering and uncertainty estimation.

```python { .api }
class MCDropoutCallback(Callback):
    """
    Enable Monte Carlo Dropout for uncertainty estimation.
    Keeps dropout layers active during validation for probabilistic predictions.
    """
    def before_validate(self): ...
    def after_validate(self): ...

class FetchPredsCallback(Callback):
    """
    Fetch predictions during the training loop with callback management.

    Parameters:
    - ds_idx: Dataset index (0 = train, 1 = valid)
    - dl: Custom DataLoader for predictions
    - with_input: Also return the inputs
    - with_decoded: Return decoded predictions
    - cbs: Callbacks to temporarily remove while predicting
    - reorder: Sort prediction results back into dataset order
    """
    def __init__(self, ds_idx=1, dl=None, with_input=False,
                 with_decoded=False, cbs=None, reorder=True): ...
```
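
With dropout kept active at inference time, repeated prediction passes differ, and their spread approximates predictive uncertainty. A sketch, assuming a `Learner` named `learn`:

```python
import torch

# Each pass samples a different dropout mask
passes = [learn.get_preds(cbs=[MCDropoutCallback()])[0] for _ in range(10)]
preds = torch.stack(passes)

mean_pred = preds.mean(0)   # point estimate
uncertainty = preds.std(0)  # higher std = less confident prediction
```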

### Advanced Mixed Precision Training

Enhanced mixed precision training with fine-grained control over scaling and gradients.

```python { .api }
class NonNativeMixedPrecision(Callback):
    """
    Manual mixed precision implementation for advanced control.
    Provides more flexibility than PyTorch's native automatic mixed precision.

    Parameters:
    - loss_scale: Loss scaling factor for gradient stability
    - flat_master: Flatten fp32 master parameters for performance
    - dynamic: Adjust the loss scale automatically
    - max_loss_scale: Maximum loss scale value
    - div_factor: Scale adjustment factor
    - scale_wait: Batches to wait before increasing the scale
    - clip: Gradient clipping value
    """
    order = 10

    def __init__(self, loss_scale=512, flat_master=False, dynamic=True,
                 max_loss_scale=2.**24, div_factor=2., scale_wait=500, clip=None): ...
```

### Integration and Production Callbacks

Callbacks for integration with external platforms and production workflows.

```python { .api }
class AzureMLCallback(Callback):
    """
    Integration with Azure Machine Learning for experiment tracking.
    Automatically logs metrics, parameters, and models to Azure ML.

    Parameters:
    - learn: Learner instance
    - log_model: Whether to log the trained model
    - model_name: Name for the logged model
    """
    def __init__(self, learn=None, log_model=False, model_name='model'): ...

class CaptumInterpretation:
    """
    Model interpretability using Facebook's Captum library.
    Provides advanced attribution and visualization methods.

    Parameters:
    - learn: Learner instance
    - cmap_name: Colormap name for visualizations
    - colors: Custom colormap colors
    - N: Number of colormap levels
    - methods: Visualization methods
    - signs: Attribution signs to display
    """
    def __init__(self, learn, cmap_name='custom blue', colors=None, N=256,
                 methods=('original_image', 'heat_map'), signs=("all", "positive")): ...
    def visualize(self, inp, metric='IG', n_steps=1000, baseline_type='zeros'): ...
    def insights(self, inp_data, debug=True): ...
```
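
A sketch of the interpretability workflow (assumes the `captum` package is installed, a `Learner` named `learn`, and an input item `item` from its dataset):

```python
interp = CaptumInterpretation(learn)
interp.visualize(item, metric='IG', n_steps=200)  # Integrated Gradients heat map
```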