Unified deep learning framework integrating PyTorch Lightning, Lightning Fabric, and Lightning Apps for training, deploying, and shipping AI products.
PyTorch Lightning components for organizing training code, managing experiments, and scaling across devices. This module provides the main training orchestrator, base classes for models and data, and the callback system for extending functionality.
The central orchestrator that automates the training loop and handles device management, logging, checkpointing, and validation. Supports distributed training across multiple GPUs, TPUs, and nodes.
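For example, a common single-node, multi-GPU configuration (a sketch; the exact set of supported flags varies across Lightning versions, so adjust to your install):

import lightning as L

# A sketch of a common multi-GPU Trainer configuration.
trainer = L.Trainer(
    accelerator="gpu",     # or "cpu", "tpu", "auto"
    devices=4,             # four GPUs on this node
    strategy="ddp",        # distributed data parallel
    precision="16-mixed",  # mixed-precision training
    max_epochs=10,
)

# Resuming from a checkpoint (the path is illustrative):
# trainer.fit(model, train_loader, ckpt_path="checkpoints/last.ckpt")

The full constructor and method signatures follow.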
class Trainer:
def __init__(
self,
logger: Union[Logger, Iterable[Logger], bool] = True,
enable_checkpointing: bool = True,
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
num_nodes: int = 1,
devices: Optional[Union[List[int], str, int]] = None,
enable_progress_bar: bool = True,
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: Optional[int] = 1,
val_check_interval: Union[int, float] = 1.0,
log_every_n_steps: int = 50,
accelerator: Optional[str] = None,
strategy: Optional[str] = None,
sync_batchnorm: bool = False,
precision: Optional[Union[int, str]] = None,
enable_model_summary: bool = True,
max_epochs: Optional[int] = None,
min_epochs: Optional[int] = None,
max_steps: int = -1,
min_steps: Optional[int] = None,
max_time: Optional[Union[str, timedelta]] = None,
limit_train_batches: Optional[Union[int, float]] = None,
limit_val_batches: Optional[Union[int, float]] = None,
limit_test_batches: Optional[Union[int, float]] = None,
limit_predict_batches: Optional[Union[int, float]] = None,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: int = 1,
profiler: Optional[Union[str, Profiler]] = None,
benchmark: Optional[bool] = None,
deterministic: Optional[Union[bool, str]] = None,
reload_dataloaders_every_n_epochs: int = 0,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
detect_anomaly: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
plugins: Optional[Union[str, list]] = None,
move_metrics_to_cpu: bool = False,
multiple_trainloader_mode: str = "max_size_cycle",
inference_mode: bool = True,
use_distributed_sampler: bool = True,
barebones: bool = False,
**kwargs
):
"""
Lightning Trainer for automating the training process.
Parameters:
- logger: Logger instance or list of loggers, or True for default TensorBoard logger
- enable_checkpointing: Enable automatic model checkpointing
- callbacks: Callback instances to customize training behavior
- default_root_dir: Default directory for logs and checkpoints
- gradient_clip_val: Gradient clipping value (None disables clipping)
- gradient_clip_algorithm: Gradient clipping algorithm ('value' or 'norm')
- num_nodes: Number of nodes for distributed training
- devices: Device specification (int, list, or 'auto')
- enable_progress_bar: Show progress bar during training
- overfit_batches: Overfit a subset of the data for debugging (float for a fraction of the data, int for a number of batches)
- track_grad_norm: Track gradient norms (p for the p-norm, 'inf' for the infinity norm, -1 to disable)
- check_val_every_n_epoch: Run validation every N epochs
- val_check_interval: Validation frequency within an epoch (float for a fraction of the epoch, int for a number of batches)
- log_every_n_steps: Log metrics every N training steps
- accelerator: Hardware accelerator ('cpu', 'gpu', 'tpu', 'auto')
- strategy: Training strategy for distributed training
- sync_batchnorm: Synchronize batch norm across devices
- precision: Training precision ('16-mixed', '32', '64', 'bf16-mixed')
- enable_model_summary: Print model summary at training start
- max_epochs: Maximum number of epochs to train
- min_epochs: Minimum number of epochs to train
- max_steps: Maximum number of training steps
- min_steps: Minimum number of training steps
- max_time: Maximum training time
- limit_train_batches: Limit training batches per epoch
- limit_val_batches: Limit validation batches
- limit_test_batches: Limit test batches
- limit_predict_batches: Limit prediction batches
- fast_dev_run: Quick debugging run (True runs a single batch of each loop; an int runs that many batches)
- accumulate_grad_batches: Accumulate gradients over N batches
- profiler: Profiler for performance analysis
- benchmark: Enable cuDNN benchmarking (can speed up training when input sizes do not vary)
- deterministic: Enable deterministic training (may impact performance)
- reload_dataloaders_every_n_epochs: Reload dataloaders periodically
- auto_lr_find: Automatically find optimal learning rate
- replace_sampler_ddp: Replace sampler with DistributedSampler for DDP
- detect_anomaly: Enable anomaly detection for debugging
- auto_scale_batch_size: Automatically scale batch size
- plugins: Additional plugins for custom functionality
- move_metrics_to_cpu: Move metrics to CPU to save GPU memory
- multiple_trainloader_mode: Mode for handling multiple train dataloaders
- inference_mode: Use inference mode during validation/test/predict
- use_distributed_sampler: Use distributed sampler in DDP
- barebones: Minimal trainer setup for maximum performance
"""
def fit(
self,
model: LightningModule,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None
):
"""
Train the model.
Parameters:
- model: LightningModule to train
- train_dataloaders: Training dataloader(s)
- val_dataloaders: Validation dataloader(s)
- datamodule: LightningDataModule containing dataloaders
- ckpt_path: Path to checkpoint to resume training from
"""
def validate(
self,
model: Optional[LightningModule] = None,
dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None
):
"""
Run validation loop.
Parameters:
- model: LightningModule to validate
- dataloaders: Validation dataloader(s)
- ckpt_path: Path to checkpoint to load
- verbose: Print validation results
- datamodule: LightningDataModule containing dataloaders
Returns:
List of validation results
"""
def test(
self,
model: Optional[LightningModule] = None,
dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None
):
"""
Run test loop.
Parameters:
- model: LightningModule to test
- dataloaders: Test dataloader(s)
- ckpt_path: Path to checkpoint to load
- verbose: Print test results
- datamodule: LightningDataModule containing dataloaders
Returns:
List of test results
"""
def predict(
self,
model: Optional[LightningModule] = None,
dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = None
):
"""
Run prediction loop.
Parameters:
- model: LightningModule for predictions
- dataloaders: Prediction dataloader(s)
- datamodule: LightningDataModule containing dataloaders
- return_predictions: Return predictions in memory
- ckpt_path: Path to checkpoint to load
Returns:
List of predictions if return_predictions=True
"""
def tune(
self,
model: LightningModule,
train_dataloaders: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
scale_batch_size_kwargs: Optional[dict] = None,
lr_find_kwargs: Optional[dict] = None
):
"""
Auto-tune model hyperparameters.
Parameters:
- model: LightningModule to tune
- train_dataloaders: Training dataloader(s)
- val_dataloaders: Validation dataloader(s)
- datamodule: LightningDataModule containing dataloaders
- scale_batch_size_kwargs: Arguments for batch size scaling
- lr_find_kwargs: Arguments for learning rate finding
Returns:
Tuning results
"""Base class for organizing PyTorch model code with standardized hooks for training, validation, testing, and prediction. Handles optimizer configuration and provides extensive customization points.
class LightningModule:
def __init__(self):
"""Base class for organizing PyTorch model logic."""
def forward(self, *args, **kwargs):
"""
Define the forward pass of the model.
Returns:
Model predictions
"""
def training_step(self, batch, batch_idx: int):
"""
Define training step logic.
Parameters:
- batch: Training batch data
- batch_idx: Index of the current batch
Returns:
Training loss (torch.Tensor) or dict with 'loss' key
"""
def validation_step(self, batch, batch_idx: int):
"""
Define validation step logic.
Parameters:
- batch: Validation batch data
- batch_idx: Index of the current batch
Returns:
Validation outputs (optional)
"""
def test_step(self, batch, batch_idx: int):
"""
Define test step logic.
Parameters:
- batch: Test batch data
- batch_idx: Index of the current batch
Returns:
Test outputs (optional)
"""
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
"""
Define prediction step logic.
Parameters:
- batch: Prediction batch data
- batch_idx: Index of the current batch
- dataloader_idx: Index of the current dataloader
Returns:
Predictions
"""
def configure_optimizers(self):
"""
Configure optimizers and learning rate schedulers.
Returns:
Optimizer, list of optimizers, or dict with optimizer/scheduler config
"""
def configure_callbacks(self):
"""
Configure model-specific callbacks.
Returns:
List of callback instances
"""
def log(self, name: str, value, prog_bar: bool = False, logger: bool = True,
on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, reduce_fx: str = "mean",
enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Optional[str] = None,
add_dataloader_idx: bool = True, batch_size: Optional[int] = None, metric_attribute: Optional[str] = None,
rank_zero_only: bool = False):
"""
Log metrics during training.
Parameters:
- name: Metric name
- value: Metric value
- prog_bar: Show in progress bar
- logger: Send to logger
- on_step: Log at current step
- on_epoch: Log at epoch end
- reduce_fx: Reduction function for distributed training
- enable_graph: Keep computation graph
- sync_dist: Synchronize across distributed processes
- sync_dist_group: Process group for synchronization
- add_dataloader_idx: Add dataloader index to metric name
- batch_size: Batch size for proper averaging
- metric_attribute: Attribute name for storing metric
- rank_zero_only: Log only on rank 0
"""Base class for organizing data loading logic, providing a clean interface for data preparation, dataset setup, and dataloader creation across different stages of training.
class LightningDataModule:
def __init__(self, *args, **kwargs):
"""Base class for organizing data loading logic."""
def setup(self, stage: Optional[str] = None):
"""
Setup datasets for different stages.
Parameters:
- stage: Current stage ('fit', 'validate', 'test', 'predict')
"""
def prepare_data(self):
"""
Download and prepare data (called once per node).
Use this for data downloading, tokenization, etc.
"""
def train_dataloader(self):
"""
Create training dataloader.
Returns:
DataLoader for training
"""
def val_dataloader(self):
"""
Create validation dataloader.
Returns:
DataLoader or list of DataLoaders for validation
"""
def test_dataloader(self):
"""
Create test dataloader.
Returns:
DataLoader or list of DataLoaders for testing
"""
def predict_dataloader(self):
"""
Create prediction dataloader.
Returns:
DataLoader or list of DataLoaders for prediction
"""
def teardown(self, stage: Optional[str] = None):
"""
Clean up after training/testing.
Parameters:
- stage: Current stage ('fit', 'validate', 'test', 'predict')
"""Base class for creating custom training callbacks that can hook into different stages of the training process to extend functionality.
class Callback:
def __init__(self):
"""Base class for creating training callbacks."""
def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
"""Called when training begins."""
def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
"""Called when training ends."""
def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the beginning of each epoch."""
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the end of each epoch."""
def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the beginning of each training epoch."""
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the end of each training epoch."""
def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the beginning of each validation epoch."""
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the end of each validation epoch."""
def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the beginning of each test epoch."""
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
"""Called at the end of each test epoch."""
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
"""Called before each training batch."""
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
"""Called after each training batch."""
def on_validation_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
"""Called before each validation batch."""
def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
"""Called after each validation batch."""
def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int):
"""Called before each test batch."""
def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int):
"""Called after each test batch."""import lightning as L
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self, size=1000):
self.size = size
def __len__(self):
return self.size
def __getitem__(self, idx):
return torch.randn(10), torch.randn(1)
class MyModel(L.LightningModule):
def __init__(self):
super().__init__()
self.layer = nn.Linear(10, 1)
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
# Training
model = MyModel()
trainer = L.Trainer(max_epochs=3)
train_loader = DataLoader(MyDataset(), batch_size=32)
trainer.fit(model, train_loader)

The same model trained through a LightningDataModule:

class MyDataModule(L.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
def setup(self, stage=None):
if stage == 'fit':
self.train_dataset = MyDataset(size=800)
self.val_dataset = MyDataset(size=200)
elif stage == 'test':
self.test_dataset = MyDataset(size=100)
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size)
# Training with DataModule
model = MyModel()
datamodule = MyDataModule(batch_size=64)
trainer = L.Trainer(max_epochs=3)
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)
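Prediction and checkpoint evaluation follow the same pattern; a sketch reusing the objects above (the checkpoint path is illustrative):

# Run inference; predictions are returned in memory by default.
predict_loader = DataLoader(MyDataset(size=50), batch_size=32)
predictions = trainer.predict(model, dataloaders=predict_loader)

# Evaluate a specific checkpoint (the path below is illustrative):
# trainer.test(model, datamodule=datamodule, ckpt_path="checkpoints/best.ckpt")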
Install with Tessl CLI

npx tessl i tessl/pypi-pytorch-lightning