# Callbacks and Lifecycle Hooks

A callback system for managing the training lifecycle, including checkpointing, early stopping, learning rate monitoring, performance monitoring, and hyperparameter tuning. Callbacks add functionality without modifying the core training loop.
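
To make the idea concrete, the following is a minimal, framework-free sketch of how a training loop can dispatch to callbacks. The names `SketchCallback`, `EpochRecorder`, and `run_training` are hypothetical and exist only for this illustration; this is not Lightning's implementation, just the general pattern.

```python
class SketchCallback:
    """Hypothetical base class: every hook is a no-op by default."""

    def on_train_start(self, state):
        pass

    def on_train_epoch_end(self, state):
        pass

    def on_train_end(self, state):
        pass


class EpochRecorder(SketchCallback):
    """Records the index of every finished epoch."""

    def __init__(self):
        self.seen = []

    def on_train_epoch_end(self, state):
        self.seen.append(state["epoch"])


def run_training(callbacks, max_epochs):
    """Toy loop: the loop body stays the same no matter which callbacks are passed."""
    state = {"epoch": 0}
    for cb in callbacks:
        cb.on_train_start(state)
    for epoch in range(max_epochs):
        state["epoch"] = epoch
        # ... the real training work for one epoch would happen here ...
        for cb in callbacks:
            cb.on_train_epoch_end(state)
    for cb in callbacks:
        cb.on_train_end(state)
    return state


recorder = EpochRecorder()
run_training([recorder], max_epochs=3)
print(recorder.seen)
```

New behavior is added by passing another object in the `callbacks` list, which is the same shape as the `Trainer(callbacks=[...])` argument used later in this document.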

## Capabilities

### Model Checkpointing

Automatically save model checkpoints during training based on monitored metrics, with support for saving top-k models and automatic cleanup.

```python { .api }
class ModelCheckpoint(Callback):
    def __init__(
        self,
        dirpath: Optional[str] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        verbose: bool = False,
        save_last: Optional[bool] = None,
        save_top_k: int = 1,
        save_weights_only: bool = False,
        mode: str = "min",
        auto_insert_metric_name: bool = True,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        every_n_epochs: Optional[int] = None,
        save_on_train_epoch_end: Optional[bool] = None,
        enable_version_counter: bool = True
    ):
        """
        Initialize ModelCheckpoint callback.

        Args:
            dirpath: Directory to save checkpoints
            filename: Checkpoint filename pattern
            monitor: Metric to monitor for saving best models
            verbose: Print checkpoint saving messages
            save_last: Always save the last checkpoint
            save_top_k: Number of best models to save
            save_weights_only: Save only model weights
            mode: 'min' or 'max' for monitored metric
            auto_insert_metric_name: Insert metric name in filename
            every_n_train_steps: Save every N training steps
            train_time_interval: Save every time interval
            every_n_epochs: Save every N epochs
            save_on_train_epoch_end: Save at end of training epoch
            enable_version_counter: Enable version counter in filename
        """

    @property
    def best_model_path(self) -> str:
        """Path to the best saved model."""

    @property
    def best_model_score(self) -> Optional[float]:
        """Score of the best saved model."""

    @property
    def last_model_path(self) -> str:
        """Path to the last saved model."""
```
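
The interaction between `save_top_k`, `mode`, and automatic cleanup can be illustrated with a small framework-free sketch. The helper `update_top_k` is hypothetical and only mimics the bookkeeping idea (keep the best `k` scores, drop the rest); it is not how `ModelCheckpoint` is implemented. It assumes a negative `k` means "keep everything".

```python
def update_top_k(saved, path, score, k, mode="min"):
    """Return (kept, removed) after considering a new checkpoint.

    `saved` maps checkpoint path -> monitored score. `removed` lists the paths
    that should not remain on disk (possibly the new one, if it did not rank).
    """
    candidates = dict(saved)
    candidates[path] = score
    # Rank best-first: ascending for 'min', descending for 'max'.
    ranked = sorted(candidates.items(), key=lambda kv: kv[1], reverse=(mode == "max"))
    # Assumption for this sketch: a negative k keeps every checkpoint.
    keep = ranked if k < 0 else ranked[:k]
    kept = dict(keep)
    removed = [p for p in candidates if p not in kept]
    return kept, removed


saved = {}
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.6]):
    saved, removed = update_top_k(saved, f"epoch={epoch}.ckpt", val_loss, k=2)

# With mode='min', the best checkpoint is the one with the lowest score.
best_path = min(saved, key=saved.get)
print(sorted(saved), best_path)
```

With `k=2`, only the two lowest validation losses survive, and `best_path` plays the role that `best_model_path` plays in the API above.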

### Early Stopping

Stop training when a monitored metric stops improving, with configurable patience and thresholds to prevent overfitting.

```python { .api }
class EarlyStopping(Callback):
    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
        check_on_train_epoch_end: Optional[bool] = None,
        log_rank_zero_only: bool = False
    ):
        """
        Initialize EarlyStopping callback.

        Args:
            monitor: Metric to monitor
            min_delta: Minimum change to qualify as improvement
            patience: Number of epochs with no improvement to wait
            verbose: Print early stopping messages
            mode: 'min' or 'max' for monitored metric
            strict: Raise error if monitored metric is not found
            check_finite: Stop if monitored metric is not finite
            stopping_threshold: Stop when metric reaches this threshold
            divergence_threshold: Stop if metric diverges beyond this
            check_on_train_epoch_end: Check metric at end of training epoch
            log_rank_zero_only: Log only on rank 0
        """

    @property
    def wait_count(self) -> int:
        """Number of epochs waited since last improvement."""

    @property
    def best_score(self) -> Optional[float]:
        """Best score achieved."""

    @property
    def stopped_epoch(self) -> int:
        """Epoch when training was stopped."""
```
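
How `patience`, `min_delta`, and `mode` combine is easier to see in code. The class below, `PatienceTracker`, is a hypothetical, simplified stand-in written for this document; it sketches the patience logic under the assumption that an improvement must beat the best score by more than `min_delta`, and it is not the library's implementation.

```python
import math


class PatienceTracker:
    """Simplified sketch of patience-based stopping (hypothetical helper)."""

    def __init__(self, patience=3, min_delta=0.0, mode="min"):
        self.patience = patience
        self.min_delta = abs(min_delta)
        self.mode = mode
        self.best_score = math.inf if mode == "min" else -math.inf
        self.wait_count = 0

    def _improved(self, value):
        # An improvement must beat the best score by more than min_delta.
        if self.mode == "min":
            return value < self.best_score - self.min_delta
        return value > self.best_score + self.min_delta

    def update(self, value):
        """Record one metric value; return True when training should stop."""
        if not math.isfinite(value):
            return True  # same idea as the check_finite option
        if self._improved(value):
            self.best_score = value
            self.wait_count = 0
        else:
            self.wait_count += 1
        return self.wait_count >= self.patience


tracker = PatienceTracker(patience=3, min_delta=0.05)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.81, 0.79, 0.80, 0.795]):
    if tracker.update(loss):
        stopped_at = epoch
        break
print(stopped_at, tracker.best_score)
```

In this run the loss keeps moving slightly after epoch 1, but never by more than `min_delta`, so the wait counter reaches `patience` and the loop stops.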

### Learning Rate Monitoring

Monitor and log learning rate changes during training, supporting multiple optimizers and schedulers.

```python { .api }
class LearningRateMonitor(Callback):
    def __init__(
        self,
        logging_interval: str = "epoch",
        log_momentum: bool = False,
        log_weight_decay: bool = False
    ):
        """
        Initialize LearningRateMonitor callback.

        Args:
            logging_interval: 'step' or 'epoch' for logging frequency
            log_momentum: Also log momentum values
            log_weight_decay: Also log weight decay values
        """
```

### Stochastic Weight Averaging

Implement stochastic weight averaging to improve model generalization by averaging weights from multiple epochs.

```python { .api }
class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_lrs: Union[float, List[float]],
        swa_epoch_start: Union[int, float] = 0.8,
        annealing_epochs: int = 10,
        annealing_strategy: str = "cos",
        avg_fn: Optional[Callable] = None,
        device: Optional[Union[torch.device, str]] = None
    ):
        """
        Initialize StochasticWeightAveraging callback.

        Args:
            swa_lrs: Learning rate(s) for SWA
            swa_epoch_start: Epoch to start SWA (int or fraction)
            annealing_epochs: Number of epochs for annealing
            annealing_strategy: 'linear' or 'cos' annealing
            avg_fn: Custom averaging function
            device: Device for SWA model
        """
```
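
Two ideas in this signature benefit from a worked example: `swa_epoch_start` given as a fraction, and the averaging step itself. The sketch below uses plain Python lists in place of model weights. The helpers `resolve_swa_start` and `update_average` are hypothetical; they assume a float is interpreted as a fraction of the total epoch count and that the default averaging is an equal-weight running mean.

```python
def resolve_swa_start(swa_epoch_start, max_epochs):
    """Assumption: a float is a fraction of max_epochs, an int is an epoch index."""
    if isinstance(swa_epoch_start, float):
        return int(max_epochs * swa_epoch_start)
    return swa_epoch_start


def update_average(averaged, current, num_averaged):
    """Equal-weight running mean over the parameter values seen so far."""
    return [a + (c - a) / (num_averaged + 1) for a, c in zip(averaged, current)]


start = resolve_swa_start(0.8, max_epochs=10)
averaged, n = None, 0
for epoch in range(10):
    weights = [float(epoch), float(2 * epoch)]  # stand-in for trained weights
    if epoch < start:
        continue  # averaging has not started yet
    if averaged is None:
        averaged = list(weights)  # first snapshot seeds the average
    else:
        averaged = update_average(averaged, weights, n)
    n += 1
print(start, averaged)
```

A custom `avg_fn` would replace `update_average` here, for example to weight recent epochs more heavily.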

### Progress Bars

Visual progress indicators during training with customizable display options and rich formatting support.

```python { .api }
class TQDMProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        process_position: int = 0
    ):
        """
        Initialize TQDM progress bar.

        Args:
            refresh_rate: Progress bar refresh rate
            process_position: Position for multiple progress bars
        """

class RichProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: RichProgressBarTheme = RichProgressBarTheme(),
        console_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Rich progress bar with enhanced formatting.

        Args:
            refresh_rate: Progress bar refresh rate
            leave: Keep progress bar after completion
            theme: Rich theme configuration
            console_kwargs: Additional console arguments
        """

class ProgressBar(Callback):
    def __init__(self):
        """Base progress bar callback."""

    def disable(self) -> None:
        """Disable the progress bar."""

    def enable(self) -> None:
        """Enable the progress bar."""
```

### Model Summary Display

Display detailed model architecture information including layer types, parameters, and memory usage.

```python { .api }
class ModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize ModelSummary callback.

        Args:
            max_depth: Maximum depth for nested modules
        """

class RichModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize RichModelSummary with enhanced formatting.

        Args:
            max_depth: Maximum depth for nested modules
        """
```

### Hyperparameter Optimization

Callbacks for automated hyperparameter tuning including batch size finding and learning rate finding.

```python { .api }
class BatchSizeFinder(Callback):
    def __init__(
        self,
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size"
    ):
        """
        Initialize BatchSizeFinder callback.

        Args:
            mode: 'power' or 'binsearch' for search strategy
            steps_per_trial: Steps per batch size trial
            init_val: Initial batch size
            max_trials: Maximum number of trials
            batch_arg_name: Argument name for batch size
        """

class LearningRateFinder(Callback):
    def __init__(
        self,
        min_lr: float = 1e-8,
        max_lr: float = 1.0,
        num_training: int = 100,
        mode: str = "exponential",
        early_stop_threshold: float = 4.0,
        update_attr: bool = False
    ):
        """
        Initialize LearningRateFinder callback.

        Args:
            min_lr: Minimum learning rate
            max_lr: Maximum learning rate
            num_training: Number of training steps
            mode: 'exponential' or 'linear' search
            early_stop_threshold: Threshold for early stopping
            update_attr: Update model's learning rate attribute
        """
```
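
The difference between the `'power'` and `'binsearch'` strategies can be sketched without any training code. In the hypothetical helper below, `fits(size)` stands in for "a trial run with this batch size succeeded" (for instance, it did not run out of memory). The function illustrates the two search strategies in general terms and is an assumption about their shape, not the library's implementation.

```python
def find_batch_size(fits, init_val=2, max_trials=25, mode="power"):
    """Return the largest batch size found for which `fits(size)` is True."""
    size, best, trials = init_val, None, 0
    # Phase 1 (both modes): keep doubling while the trial succeeds.
    while trials < max_trials and fits(size):
        best = size
        size *= 2
        trials += 1
    if mode != "binsearch" or best is None:
        return best
    # Phase 2 (binsearch only): narrow the gap between the last success
    # and the first size that was too large.
    low, high = best, size
    while trials < max_trials and high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
        trials += 1
    return low


def fits(size):
    """Hypothetical limit: pretend anything above 100 samples runs out of memory."""
    return size <= 100


power_result = find_batch_size(fits, mode="power")
binsearch_result = find_batch_size(fits, mode="binsearch")
print(power_result, binsearch_result)
```

Doubling alone stops at the last power-of-two multiple that fits, while the binary search phase spends extra trials to get closer to the actual limit.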

### Fine-tuning Callbacks

Specialized callbacks for transfer learning and progressive fine-tuning strategies.

```python { .api }
class BaseFinetuning(Callback):
    def __init__(self, unfreeze_at_epoch: int = 10, lambda_func: Optional[Callable] = None):
        """
        Base class for fine-tuning callbacks.

        Args:
            unfreeze_at_epoch: Epoch to unfreeze parameters
            lambda_func: Function to determine learning rates
        """

    def freeze_before_training(self, pl_module: LightningModule) -> None:
        """Freeze parameters before training starts."""

    def finetune_function(
        self,
        pl_module: LightningModule,
        current_epoch: int,
        optimizer: Optimizer,
        optimizer_idx: int
    ) -> None:
        """Function called during fine-tuning."""

class BackboneFinetuning(BaseFinetuning):
    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Optional[Callable] = None,
        backbone_initial_ratio_lr: float = 0.1,
        backbone_initial_lr: Optional[float] = None,
        should_align: bool = True,
        initial_denom_lr: float = 10.0,
        train_bn: bool = True
    ):
        """
        Fine-tuning callback for backbone networks.

        Args:
            unfreeze_backbone_at_epoch: Epoch to unfreeze backbone
            lambda_func: Learning rate scheduling function
            backbone_initial_ratio_lr: Initial backbone LR ratio
            backbone_initial_lr: Initial backbone learning rate
            should_align: Align learning rates
            initial_denom_lr: Initial denominator for LR calculation
            train_bn: Train batch normalization layers
        """
```

### Performance Monitoring

Callbacks for monitoring training performance, throughput, and resource utilization.

```python { .api }
class ThroughputMonitor(Callback):
    def __init__(
        self,
        length_key: str = "seq_len",
        batch_size_key: str = "batch_size",
        window_size: int = 100
    ):
        """
        Initialize ThroughputMonitor callback.

        Args:
            length_key: Key for sequence length in batch
            batch_size_key: Key for batch size
            window_size: Window size for throughput calculation
        """

class DeviceStatsMonitor(Callback):
    def __init__(self, cpu_stats: Optional[bool] = None):
        """
        Initialize DeviceStatsMonitor callback.

        Args:
            cpu_stats: Monitor CPU statistics
        """

class Timer(Callback):
    def __init__(self, duration: Optional[Union[str, timedelta]] = None, interval: str = "step"):
        """
        Initialize Timer callback for training duration control.

        Args:
            duration: Maximum training duration
            interval: 'step' or 'epoch' for timing
        """
```
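
Because `duration` accepts either a string or a `timedelta`, a short sketch of how a duration limit might be normalized and checked can help. The helpers `parse_duration` and `time_is_up` are hypothetical, and the `DD:HH:MM:SS` string layout is an assumption made for this example rather than something stated in the signature above.

```python
from datetime import timedelta


def parse_duration(duration):
    """Normalize a duration to a timedelta (or None for 'no limit').

    Assumption for this sketch: strings use the layout DD:HH:MM:SS.
    """
    if duration is None or isinstance(duration, timedelta):
        return duration
    days, hours, minutes, seconds = (int(part) for part in duration.strip().split(":"))
    return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


def time_is_up(elapsed, limit):
    """True once the elapsed time has reached the limit; never True without a limit."""
    return limit is not None and elapsed >= limit


limit = parse_duration("00:12:30:00")
print(limit, time_is_up(timedelta(hours=13), limit))
```

A check like `time_is_up` would run at whichever granularity `interval` selects, after each step or after each epoch.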

### Custom Callback Creation

```python { .api }
class LambdaCallback(Callback):
    def __init__(self, **kwargs):
        """
        Create callback from lambda functions.

        Args:
            **kwargs: Mapping of hook names to functions
        """
```
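
The "mapping of hook names to functions" idea can be sketched in a few lines. The class `SketchLambdaCallback` and its three hook names are hypothetical and chosen only for illustration; the sketch shows the general pattern of attaching supplied functions as hooks and is not the library's implementation.

```python
def _noop(*args, **kwargs):
    """Default hook body: do nothing."""
    return None


class SketchLambdaCallback:
    """Hypothetical stand-in that turns keyword arguments into hooks."""

    # A deliberately small set of hook names for this sketch.
    HOOKS = ("on_train_start", "on_train_epoch_end", "on_train_end")

    def __init__(self, **hooks):
        unknown = set(hooks) - set(self.HOOKS)
        if unknown:
            raise ValueError(f"Unknown hook(s): {sorted(unknown)}")
        # Supplied functions override the no-op default for that hook.
        for name in self.HOOKS:
            setattr(self, name, hooks.get(name, _noop))


events = []
cb = SketchLambdaCallback(
    on_train_start=lambda: events.append("start"),
    on_train_end=lambda: events.append("end"),
)
cb.on_train_start()
cb.on_train_epoch_end()  # not supplied, so this is a no-op
cb.on_train_end()
print(events)
```

This style suits one-off hooks where defining a full `Callback` subclass would be more ceremony than the logic deserves.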

## Usage Examples

### Basic Callback Usage

```python
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Configure callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min'
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)

# Use callbacks in trainer
trainer = Trainer(
    callbacks=[checkpoint_callback, early_stopping],
    max_epochs=100
)
```

### Custom Callback Example

```python
import lightning as L

class MetricLoggingCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Log custom metrics at end of each epoch
        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch

        # Custom logging logic
        if 'train_loss' in metrics:
            print(f"Epoch {epoch}: Train Loss = {metrics['train_loss']:.4f}")

        # Save metrics to file
        with open('metrics.log', 'a') as f:
            f.write(f"Epoch {epoch}: {dict(metrics)}\n")

# Use custom callback (Trainer is accessed through the `L` alias imported above)
trainer = L.Trainer(callbacks=[MetricLoggingCallback()])
```