# Training and Model Organization

PyTorch Lightning components for organizing training code, managing experiments, and scaling across devices. This module provides the main training orchestrator, base classes for models and data, and the callback system for extending functionality.

## Capabilities

### Trainer

The central orchestrator that automates the training loop, handles device management, logging, checkpointing, and validation. Supports distributed training across multiple GPUs, TPUs, and nodes.

```python { .api }
class Trainer:
    def __init__(
        self,
        logger: Union[Logger, Iterable[Logger], bool] = True,
        enable_checkpointing: bool = True,
        callbacks: Optional[Union[List[Callback], Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: Optional[Union[int, float]] = None,
        gradient_clip_algorithm: Optional[str] = None,
        num_nodes: int = 1,
        devices: Optional[Union[List[int], str, int]] = None,
        enable_progress_bar: bool = True,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: Optional[int] = 1,
        val_check_interval: Union[int, float] = 1.0,
        log_every_n_steps: int = 50,
        accelerator: Optional[str] = None,
        strategy: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: Optional[Union[int, str]] = None,
        enable_model_summary: bool = True,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        max_steps: int = -1,
        min_steps: Optional[int] = None,
        max_time: Optional[Union[str, timedelta]] = None,
        limit_train_batches: Optional[Union[int, float]] = None,
        limit_val_batches: Optional[Union[int, float]] = None,
        limit_test_batches: Optional[Union[int, float]] = None,
        limit_predict_batches: Optional[Union[int, float]] = None,
        fast_dev_run: Union[int, bool] = False,
        accumulate_grad_batches: int = 1,
        profiler: Optional[Union[str, Profiler]] = None,
        benchmark: Optional[bool] = None,
        deterministic: Optional[Union[bool, str]] = None,
        reload_dataloaders_every_n_epochs: int = 0,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        detect_anomaly: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        plugins: Optional[Union[str, list]] = None,
        move_metrics_to_cpu: bool = False,
        multiple_trainloader_mode: str = "max_size_cycle",
        inference_mode: bool = True,
        use_distributed_sampler: bool = True,
        barebones: bool = False,
        **kwargs
    ):
        """
        Lightning Trainer for automating the training process.

        Parameters:
        - logger: Logger instance or list of loggers, or True for default TensorBoard logger
        - enable_checkpointing: Enable automatic model checkpointing
        - callbacks: Callback instances to customize training behavior
        - default_root_dir: Default directory for logs and checkpoints
        - gradient_clip_val: Gradient clipping value (0 means no clipping)
        - gradient_clip_algorithm: Gradient clipping algorithm ('value' or 'norm')
        - num_nodes: Number of nodes for distributed training
        - devices: Device specification (int, list, or 'auto')
        - enable_progress_bar: Show progress bar during training
        - overfit_batches: Overfit on a subset of data for debugging
        - track_grad_norm: Track gradient norms (int for L-norm, -1 to disable)
        - check_val_every_n_epoch: Run validation every N epochs
        - val_check_interval: Validation frequency within an epoch
        - log_every_n_steps: Log metrics every N training steps
        - accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')
        - strategy: Training strategy for distributed training
        - sync_batchnorm: Synchronize batch norm across devices
        - precision: Training precision ('16-mixed', '32', '64', 'bf16-mixed')
        - enable_model_summary: Print model summary at training start
        - max_epochs: Maximum number of epochs to train
        - min_epochs: Minimum number of epochs to train
        - max_steps: Maximum number of training steps
        - min_steps: Minimum number of training steps
        - max_time: Maximum training time
        - limit_train_batches: Limit training batches per epoch
        - limit_val_batches: Limit validation batches
        - limit_test_batches: Limit test batches
        - limit_predict_batches: Limit prediction batches
        - fast_dev_run: Quick development run with limited batches
        - accumulate_grad_batches: Accumulate gradients over N batches
        - profiler: Profiler for performance analysis
        - benchmark: Enable cuDNN benchmarking for consistent input sizes
        - deterministic: Enable deterministic training (may impact performance)
        - reload_dataloaders_every_n_epochs: Reload dataloaders periodically
        - auto_lr_find: Automatically find optimal learning rate
        - replace_sampler_ddp: Replace sampler with DistributedSampler for DDP
        - detect_anomaly: Enable anomaly detection for debugging
        - auto_scale_batch_size: Automatically scale batch size
        - plugins: Additional plugins for custom functionality
        - move_metrics_to_cpu: Move metrics to CPU to save GPU memory
        - multiple_trainloader_mode: Mode for handling multiple train dataloaders
        - inference_mode: Use inference mode during validation/test/predict
        - use_distributed_sampler: Use distributed sampler in DDP
        - barebones: Minimal trainer setup for maximum performance
        """

    def fit(
        self,
        model: LightningModule,
        train_dataloaders: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None
    ):
        """
        Train the model.

        Parameters:
        - model: LightningModule to train
        - train_dataloaders: Training dataloader(s)
        - val_dataloaders: Validation dataloader(s)
        - datamodule: LightningDataModule containing dataloaders
        - ckpt_path: Path to checkpoint to resume training from
        """

    def validate(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = None,
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None
    ):
        """
        Run validation loop.

        Parameters:
        - model: LightningModule to validate
        - dataloaders: Validation dataloader(s)
        - ckpt_path: Path to checkpoint to load
        - verbose: Print validation results
        - datamodule: LightningDataModule containing dataloaders

        Returns:
            List of validation results
        """

    def test(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = None,
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None
    ):
        """
        Run test loop.

        Parameters:
        - model: LightningModule to test
        - dataloaders: Test dataloader(s)
        - ckpt_path: Path to checkpoint to load
        - verbose: Print test results
        - datamodule: LightningDataModule containing dataloaders

        Returns:
            List of test results
        """

    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = None
    ):
        """
        Run prediction loop.

        Parameters:
        - model: LightningModule for predictions
        - dataloaders: Prediction dataloader(s)
        - datamodule: LightningDataModule containing dataloaders
        - return_predictions: Return predictions in memory
        - ckpt_path: Path to checkpoint to load

        Returns:
            List of predictions if return_predictions=True
        """

    def tune(
        self,
        model: LightningModule,
        train_dataloaders: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        scale_batch_size_kwargs: Optional[dict] = None,
        lr_find_kwargs: Optional[dict] = None
    ):
        """
        Auto-tune model hyperparameters.

        Parameters:
        - model: LightningModule to tune
        - train_dataloaders: Training dataloader(s)
        - val_dataloaders: Validation dataloader(s)
        - datamodule: LightningDataModule containing dataloaders
        - scale_batch_size_kwargs: Arguments for batch size scaling
        - lr_find_kwargs: Arguments for learning rate finding

        Returns:
            Tuning results
        """
```

### LightningModule

Base class for organizing PyTorch model code with standardized hooks for training, validation, testing, and prediction. Handles optimizer configuration and provides extensive customization points.

```python { .api }
class LightningModule:
    def __init__(self):
        """Base class for organizing PyTorch model logic."""

    def forward(self, *args, **kwargs):
        """
        Define the forward pass of the model.

        Returns:
            Model predictions
        """

    def training_step(self, batch, batch_idx: int):
        """
        Define training step logic.

        Parameters:
        - batch: Training batch data
        - batch_idx: Index of the current batch

        Returns:
            Training loss (torch.Tensor) or dict with 'loss' key
        """

    def validation_step(self, batch, batch_idx: int):
        """
        Define validation step logic.

        Parameters:
        - batch: Validation batch data
        - batch_idx: Index of the current batch

        Returns:
            Validation outputs (optional)
        """

    def test_step(self, batch, batch_idx: int):
        """
        Define test step logic.

        Parameters:
        - batch: Test batch data
        - batch_idx: Index of the current batch

        Returns:
            Test outputs (optional)
        """

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """
        Define prediction step logic.

        Parameters:
        - batch: Prediction batch data
        - batch_idx: Index of the current batch
        - dataloader_idx: Index of the current dataloader

        Returns:
            Predictions
        """

    def configure_optimizers(self):
        """
        Configure optimizers and learning rate schedulers.

        Returns:
            Optimizer, list of optimizers, or dict with optimizer/scheduler config
        """

    def configure_callbacks(self):
        """
        Configure model-specific callbacks.

        Returns:
            List of callback instances
        """

    def log(self, name: str, value, prog_bar: bool = False, logger: bool = True,
            on_step: bool = None, on_epoch: bool = None, reduce_fx: str = "mean",
            enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: str = None,
            add_dataloader_idx: bool = True, batch_size: int = None, metric_attribute: str = None,
            rank_zero_only: bool = False):
        """
        Log metrics during training.

        Parameters:
        - name: Metric name
        - value: Metric value
        - prog_bar: Show in progress bar
        - logger: Send to logger
        - on_step: Log at current step
        - on_epoch: Log at epoch end
        - reduce_fx: Reduction function for distributed training
        - enable_graph: Keep computation graph
        - sync_dist: Synchronize across distributed processes
        - sync_dist_group: Process group for synchronization
        - add_dataloader_idx: Add dataloader index to metric name
        - batch_size: Batch size for proper averaging
        - metric_attribute: Attribute name for storing metric
        - rank_zero_only: Log only on rank 0
        """
```

### LightningDataModule

Base class for organizing data loading logic, providing a clean interface for data preparation, dataset setup, and dataloader creation across different stages of training.

```python { .api }
class LightningDataModule:
    def __init__(self, *args, **kwargs):
        """Base class for organizing data loading logic."""

    def setup(self, stage: str = None):
        """
        Setup datasets for different stages.

        Parameters:
        - stage: Current stage ('fit', 'validate', 'test', 'predict')
        """

    def prepare_data(self):
        """
        Download and prepare data (called once per node).
        Use this for data downloading, tokenization, etc.
        """

    def train_dataloader(self):
        """
        Create training dataloader.

        Returns:
            DataLoader for training
        """

    def val_dataloader(self):
        """
        Create validation dataloader.

        Returns:
            DataLoader or list of DataLoaders for validation
        """

    def test_dataloader(self):
        """
        Create test dataloader.

        Returns:
            DataLoader or list of DataLoaders for testing
        """

    def predict_dataloader(self):
        """
        Create prediction dataloader.

        Returns:
            DataLoader or list of DataLoaders for prediction
        """

    def teardown(self, stage: str = None):
        """
        Clean up after training/testing.

        Parameters:
        - stage: Current stage ('fit', 'validate', 'test', 'predict')
        """
```

### Callback System

Base class for creating custom training callbacks that can hook into different stages of the training process to extend functionality.

```python { .api }
class Callback:
    def __init__(self):
        """Base class for creating training callbacks."""

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called when training begins."""

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called when training ends."""

    def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each epoch."""

    def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each epoch."""

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each training epoch."""

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each training epoch."""

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each validation epoch."""

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each validation epoch."""

    def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the beginning of each test epoch."""

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        """Called at the end of each test epoch."""

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
        """Called before each training batch."""

    def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
        """Called after each training batch."""

    def on_validation_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
        """Called before each validation batch."""

    def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
        """Called after each validation batch."""

    def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
        """Called before each test batch."""

    def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
        """Called after each test batch."""
```

## Usage Examples

### Basic Training Setup

```python
import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, size=1000):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return torch.randn(10), torch.randn(1)

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# Training
model = MyModel()
trainer = L.Trainer(max_epochs=3)
train_loader = DataLoader(MyDataset(), batch_size=32)
trainer.fit(model, train_loader)
```

### Using DataModule

```python
class MyDataModule(L.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        if stage == 'fit':
            self.train_dataset = MyDataset(size=800)
            self.val_dataset = MyDataset(size=200)
        elif stage == 'test':
            self.test_dataset = MyDataset(size=100)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

# Training with DataModule
model = MyModel()
datamodule = MyDataModule(batch_size=64)
trainer = L.Trainer(max_epochs=3)
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)
```