0
# Core Training Components
1
2
Essential components for structuring deep learning training workflows in Lightning. These components provide the foundation for organized, scalable, and reproducible machine learning training.
3
4
## Capabilities
5
6
### Trainer
7
8
The main entry point for Lightning training that orchestrates the entire training process, handling distributed training, logging, checkpointing, and validation automatically.
9
10
```python { .api }
11
class Trainer:
12
def __init__(
13
self,
14
accelerator: str = "auto",
15
strategy: str = "auto",
16
devices: Union[List[int], str, int] = "auto",
17
num_nodes: int = 1,
18
precision: Union[str, int] = "32-true",
19
logger: Union[Logger, bool] = True,
20
callbacks: Optional[List[Callback]] = None,
21
fast_dev_run: Union[bool, int] = False,
22
max_epochs: Optional[int] = None,
23
min_epochs: Optional[int] = None,
24
max_steps: int = -1,
25
min_steps: Optional[int] = None,
26
max_time: Optional[Union[str, timedelta]] = None,
27
limit_train_batches: Union[int, float] = 1.0,
28
limit_val_batches: Union[int, float] = 1.0,
29
limit_test_batches: Union[int, float] = 1.0,
30
limit_predict_batches: Union[int, float] = 1.0,
31
overfit_batches: Union[int, float] = 0.0,
32
val_check_interval: Union[int, float] = 1.0,
33
check_val_every_n_epoch: Optional[int] = 1,
34
num_sanity_val_steps: int = 2,
35
log_every_n_steps: int = 50,
36
enable_checkpointing: bool = True,
37
enable_progress_bar: bool = True,
38
enable_model_summary: bool = True,
39
accumulate_grad_batches: int = 1,
40
gradient_clip_val: Optional[float] = None,
41
gradient_clip_algorithm: Optional[str] = None,
42
deterministic: Optional[bool] = None,
43
benchmark: Optional[bool] = None,
44
inference_mode: bool = True,
45
use_distributed_sampler: bool = True,
46
profiler: Optional[Profiler] = None,
47
detect_anomaly: bool = False,
48
barebones: bool = False,
49
plugins: Optional[List[Any]] = None,
50
sync_batchnorm: bool = False,
51
reload_dataloaders_every_n_epochs: int = 0,
52
default_root_dir: Optional[str] = None,
53
**kwargs
54
):
55
"""
56
Initialize the Lightning Trainer.
57
58
Args:
59
accelerator: Hardware accelerator type ('cpu', 'gpu', 'tpu', 'auto')
60
strategy: Distributed training strategy ('ddp', 'fsdp', 'deepspeed', etc.)
61
devices: Which devices to use for training
62
num_nodes: Number of nodes for distributed training
63
precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)
64
logger: Logger instance or True/False to enable/disable default logger
65
callbacks: List of callbacks to use during training
66
fast_dev_run: Run a single batch (True) or n batches (int) of train, val and test as a quick debugging check
67
max_epochs: Maximum number of epochs to train
68
min_epochs: Minimum number of epochs to train
69
max_steps: Maximum number of training steps
70
min_steps: Minimum number of training steps
71
max_time: Maximum training time, given as a 'DD:HH:MM:SS' string or a datetime.timedelta
72
limit_train_batches: Limit training batches per epoch (float = fraction of batches, int = number of batches)
73
limit_val_batches: Limit validation batches per epoch (float = fraction of batches, int = number of batches)
74
val_check_interval: How often to run validation within a training epoch (float = fraction of the epoch, int = every N training batches)
75
enable_checkpointing: Enable automatic checkpointing
76
enable_progress_bar: Show progress bar during training
77
accumulate_grad_batches: Gradient accumulation steps
78
gradient_clip_val: Gradient clipping value
79
deterministic: Make training deterministic
80
profiler: Profiler for performance analysis
81
"""
82
83
def fit(
84
self,
85
model: LightningModule,
86
train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
87
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
88
datamodule: Optional[LightningDataModule] = None,
89
ckpt_path: Optional[str] = None
90
) -> None:
91
"""
92
Fit the model with training and validation data.
93
94
Args:
95
model: LightningModule to train
96
train_dataloaders: Training data loaders
97
val_dataloaders: Validation data loaders
98
datamodule: LightningDataModule containing data loaders
99
ckpt_path: Path to checkpoint to resume from
100
"""
101
102
def validate(
103
self,
104
model: Optional[LightningModule] = None,
105
dataloaders: Optional[EVAL_DATALOADERS] = None,
106
ckpt_path: Optional[str] = None,
107
verbose: bool = True,
108
datamodule: Optional[LightningDataModule] = None
109
) -> List[Dict[str, float]]:
110
"""
111
Run validation on the model.
112
113
Args:
114
model: LightningModule to validate
115
dataloaders: Validation data loaders
116
ckpt_path: Path to checkpoint to load
117
verbose: Print validation results
118
datamodule: LightningDataModule containing data loaders
119
120
Returns:
121
List of validation metrics dictionaries
122
"""
123
124
def test(
125
self,
126
model: Optional[LightningModule] = None,
127
dataloaders: Optional[EVAL_DATALOADERS] = None,
128
ckpt_path: Optional[str] = None,
129
verbose: bool = True,
130
datamodule: Optional[LightningDataModule] = None
131
) -> List[Dict[str, float]]:
132
"""
133
Run testing on the model.
134
135
Args:
136
model: LightningModule to test
137
dataloaders: Test data loaders
138
ckpt_path: Path to checkpoint to load
139
verbose: Print test results
140
datamodule: LightningDataModule containing data loaders
141
142
Returns:
143
List of test metrics dictionaries
144
"""
145
146
def predict(
147
self,
148
model: Optional[LightningModule] = None,
149
dataloaders: Optional[EVAL_DATALOADERS] = None,
150
datamodule: Optional[LightningDataModule] = None,
151
return_predictions: Optional[bool] = None,
152
ckpt_path: Optional[str] = None
153
) -> Optional[List[Any]]:
154
"""
155
Run prediction on the model.
156
157
Args:
158
model: LightningModule to use for prediction
159
dataloaders: Prediction data loaders
160
datamodule: LightningDataModule containing data loaders
161
return_predictions: Whether to return predictions
162
ckpt_path: Path to checkpoint to load
163
164
Returns:
165
List of predictions if return_predictions=True
166
"""
167
168
def tune(
169
self,
170
model: LightningModule,
171
train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
172
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
173
datamodule: Optional[LightningDataModule] = None,
174
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
175
lr_find_kwargs: Optional[Dict[str, Any]] = None
176
) -> Dict[str, Any]:
177
"""
178
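Tune hyperparameters for the model (batch size scaling and learning rate finding). Note: Lightning 2.x removed Trainer.tune; use lightning.pytorch.tuner.Tuner there instead.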
179
180
Args:
181
model: LightningModule to tune
182
train_dataloaders: Training data loaders
183
val_dataloaders: Validation data loaders
184
datamodule: LightningDataModule containing data loaders
185
scale_batch_size_kwargs: Arguments for batch size scaling
186
lr_find_kwargs: Arguments for learning rate finding
187
188
Returns:
189
Dictionary with tuning results
190
"""
191
```
192
193
### LightningModule
194
195
Base class for organizing PyTorch code in Lightning. Defines model architecture, training logic, optimization, and provides hooks for the training lifecycle.
196
197
```python { .api }
198
class LightningModule(nn.Module):
199
def __init__(self):
200
"""Initialize the LightningModule."""
201
super().__init__()
202
203
def forward(self, *args, **kwargs) -> Any:
204
"""
205
Define the forward pass of the model.
206
207
Args:
208
*args: Positional arguments
209
**kwargs: Keyword arguments
210
211
Returns:
212
Model output
213
"""
214
215
def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
216
"""
217
Define a single training step.
218
219
Args:
220
batch: Batch of training data
221
batch_idx: Index of the current batch
222
223
Returns:
224
Loss tensor, dictionary with 'loss' key, or None to skip to the next batch
225
"""
226
227
def validation_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:
228
"""
229
Define a single validation step.
230
231
Args:
232
batch: Batch of validation data
233
batch_idx: Index of the current batch
234
235
Returns:
236
Optional loss tensor or metrics dictionary
237
"""
238
239
def test_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:
240
"""
241
Define a single test step.
242
243
Args:
244
batch: Batch of test data
245
batch_idx: Index of the current batch
246
247
Returns:
248
Optional loss tensor or metrics dictionary
249
"""
250
251
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
252
"""
253
Define a single prediction step.
254
255
Args:
256
batch: Batch of prediction data
257
batch_idx: Index of the current batch
258
dataloader_idx: Index of the dataloader
259
260
Returns:
261
Model predictions
262
"""
263
264
def configure_optimizers(self) -> Union[Optimizer, Dict[str, Any]]:
265
"""
266
Configure optimizers and learning rate schedulers.
267
268
Returns:
269
Optimizer or dictionary with optimizer/scheduler configuration
270
"""
271
272
def configure_callbacks(self) -> Union[List[Callback], Callback]:
273
"""
274
Configure callbacks for this model.
275
276
Returns:
277
List of callbacks or single callback
278
"""
279
280
def log(
281
self,
282
name: str,
283
value: Any,
284
prog_bar: bool = False,
285
logger: bool = True,
286
on_step: Optional[bool] = None,
287
on_epoch: Optional[bool] = None,
288
reduce_fx: str = "mean",
289
enable_graph: bool = False,
290
sync_dist: bool = False,
291
sync_dist_group: Optional[Any] = None,
292
add_dataloader_idx: bool = True,
293
batch_size: Optional[int] = None,
294
metric_attribute: Optional[str] = None,
295
rank_zero_only: bool = False
296
) -> None:
297
"""
298
Log a key-value pair.
299
300
Args:
301
name: Name of the metric
302
value: Value to log
303
prog_bar: Show in progress bar
304
logger: Send to logger
305
on_step: Log at each step
306
on_epoch: Log at each epoch
307
reduce_fx: Reduction function applied over step values to produce the epoch-level value
308
sync_dist: Synchronize across distributed processes
309
batch_size: Current batch size for proper reduction
310
"""
311
312
def log_dict(
313
self,
314
dictionary: Dict[str, Any],
315
prog_bar: bool = False,
316
logger: bool = True,
317
on_step: Optional[bool] = None,
318
on_epoch: Optional[bool] = None,
319
reduce_fx: str = "mean",
320
enable_graph: bool = False,
321
sync_dist: bool = False,
322
sync_dist_group: Optional[Any] = None,
323
add_dataloader_idx: bool = True,
324
batch_size: Optional[int] = None,
325
rank_zero_only: bool = False
326
) -> None:
327
"""
328
Log a dictionary of key-value pairs.
329
330
Args:
331
dictionary: Dictionary of metrics to log
332
prog_bar: Show in progress bar
333
logger: Send to logger
334
on_step: Log at each step
335
on_epoch: Log at each epoch
336
reduce_fx: Reduction function applied over step values to produce the epoch-level value
337
sync_dist: Synchronize across distributed processes
338
batch_size: Current batch size for proper reduction
339
"""
340
```
341
342
### LightningDataModule
343
344
Encapsulates data loading logic including data downloading, preparation, splitting, and data loader creation. Provides a clean interface for data handling across train/val/test splits.
345
346
```python { .api }
347
class LightningDataModule:
348
def __init__(self):
349
"""Initialize the LightningDataModule."""
350
351
def prepare_data(self) -> None:
352
"""
353
Download and prepare data. Called from a single process only (by default local rank 0 on each node).
354
Use this for downloads and preprocessing that should not be repeated on every device.
355
"""
356
357
def setup(self, stage: str) -> None:
358
"""
359
Set up datasets for each stage.
360
361
Args:
362
stage: 'fit', 'validate', 'test', or 'predict'
363
"""
364
365
def train_dataloader(self) -> TRAIN_DATALOADERS:
366
"""
367
Create training data loader.
368
369
Returns:
370
Training data loader(s)
371
"""
372
373
def val_dataloader(self) -> EVAL_DATALOADERS:
374
"""
375
Create validation data loader.
376
377
Returns:
378
Validation data loader(s)
379
"""
380
381
def test_dataloader(self) -> EVAL_DATALOADERS:
382
"""
383
Create test data loader.
384
385
Returns:
386
Test data loader(s)
387
"""
388
389
def predict_dataloader(self) -> EVAL_DATALOADERS:
390
"""
391
Create prediction data loader.
392
393
Returns:
394
Prediction data loader(s)
395
"""
396
397
def teardown(self, stage: str) -> None:
398
"""
399
Clean up at the end of fit, validate, test, or predict.
400
401
Args:
402
stage: 'fit', 'validate', 'test', or 'predict'
403
"""
404
405
def state_dict(self) -> Dict[str, Any]:
406
"""
407
Called when saving a checkpoint.
408
409
Returns:
410
Dictionary of state to save
411
"""
412
413
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
414
"""
415
Called when loading a checkpoint.
416
417
Args:
418
state_dict: Dictionary of saved state
419
"""
420
```
421
422
### Callback
423
424
Base class for creating custom callbacks to hook into the training lifecycle. Callbacks provide a way to add functionality at specific points during training.
425
426
```python { .api }
427
class Callback:
428
def __init__(self):
429
"""Initialize the callback."""
430
431
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
432
"""Called when training begins."""
433
434
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
435
"""Called when training ends."""
436
437
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
438
"""Called when validation begins."""
439
440
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
441
"""Called when validation ends."""
442
443
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
444
"""Called when testing begins."""
445
446
def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
447
"""Called when testing ends."""
448
449
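def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
450
"""Called when an epoch begins. Note: removed in later Lightning releases; prefer on_train_epoch_start / on_validation_epoch_start."""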
451
452
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
453
"""Called when an epoch ends. Note: removed in later Lightning releases; prefer on_train_epoch_end / on_validation_epoch_end."""
454
455
def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
456
"""Called when a training epoch begins."""
457
458
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
459
"""Called when a training epoch ends."""
460
461
def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
462
"""Called when a validation epoch begins."""
463
464
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
465
"""Called when a validation epoch ends."""
466
467
def on_train_batch_start(
468
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
469
) -> None:
470
"""Called when a training batch begins."""
471
472
def on_train_batch_end(
473
self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
474
) -> None:
475
"""Called when a training batch ends."""
476
477
def on_validation_batch_start(
478
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
479
) -> None:
480
"""Called when a validation batch begins."""
481
482
def on_validation_batch_end(
483
self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0
484
) -> None:
485
"""Called when a validation batch ends."""
486
487
def on_before_optimizer_step(
488
self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer, opt_idx: int
489
) -> None:
490
"""Called before optimizer step."""
491
492
def on_before_zero_grad(
493
self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
494
) -> None:
495
"""Called before gradients are zeroed."""
496
497
def on_save_checkpoint(
498
self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]
499
) -> None:
500
"""Called when saving a checkpoint."""
501
502
def on_load_checkpoint(
503
self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]
504
) -> None:
505
"""Called when loading a checkpoint."""
506
507
def state_dict(self) -> Dict[str, Any]:
508
"""
509
Called when saving a checkpoint.
510
511
Returns:
512
Dictionary of callback state to save
513
"""
514
515
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
516
"""
517
Called when loading a checkpoint.
518
519
Args:
520
state_dict: Dictionary of saved callback state
521
"""
522
```