The Deep Learning framework to train, deploy, and ship AI products Lightning fast.
Essential components for structuring deep learning training workflows in Lightning, providing the foundation for organized, scalable, and reproducible training.
The Trainer is the main entry point for Lightning training. It orchestrates the entire training loop and handles distributed training, logging, checkpointing, and validation automatically.
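A minimal usage sketch; `MyModel`, `train_dataset`, and `val_dataset` are hypothetical placeholders (any LightningModule subclass and PyTorch datasets will do), and only the Trainer calls are the documented API:

import lightning as L
from torch.utils.data import DataLoader

model = MyModel()  # hypothetical LightningModule subclass
train_loader = DataLoader(train_dataset, batch_size=32)  # assumed datasets
val_loader = DataLoader(val_dataset, batch_size=32)

trainer = L.Trainer(accelerator="auto", devices="auto", max_epochs=10, precision="16-mixed")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=val_loader)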
class Trainer:
def __init__(
self,
accelerator: str = "auto",
strategy: str = "auto",
devices: Union[List[int], str, int] = "auto",
num_nodes: int = 1,
precision: Union[str, int] = "32-true",
logger: Union[Logger, bool] = True,
callbacks: Optional[List[Callback]] = None,
fast_dev_run: Union[bool, int] = False,
max_epochs: Optional[int] = None,
min_epochs: Optional[int] = None,
max_steps: int = -1,
min_steps: Optional[int] = None,
max_time: Optional[Union[str, timedelta]] = None,
limit_train_batches: Union[int, float] = 1.0,
limit_val_batches: Union[int, float] = 1.0,
limit_test_batches: Union[int, float] = 1.0,
limit_predict_batches: Union[int, float] = 1.0,
overfit_batches: Union[int, float] = 0.0,
val_check_interval: Union[int, float] = 1.0,
check_val_every_n_epoch: Optional[int] = 1,
num_sanity_val_steps: int = 2,
log_every_n_steps: int = 50,
enable_checkpointing: bool = True,
enable_progress_bar: bool = True,
enable_model_summary: bool = True,
accumulate_grad_batches: int = 1,
gradient_clip_val: Optional[float] = None,
gradient_clip_algorithm: Optional[str] = None,
deterministic: Optional[bool] = None,
benchmark: Optional[bool] = None,
inference_mode: bool = True,
use_distributed_sampler: bool = True,
profiler: Optional[Profiler] = None,
detect_anomaly: bool = False,
barebones: bool = False,
plugins: Optional[List[Any]] = None,
sync_batchnorm: bool = False,
reload_dataloaders_every_n_epochs: int = 0,
default_root_dir: Optional[str] = None,
**kwargs
):
"""
Initialize the Lightning Trainer.
Args:
accelerator: Hardware accelerator type ('cpu', 'gpu', 'tpu', 'auto')
strategy: Distributed training strategy ('ddp', 'fsdp', 'deepspeed', etc.)
devices: Which devices to use for training
num_nodes: Number of nodes for distributed training
precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)
logger: Logger instance, or True/False to enable/disable the default logger
callbacks: List of callbacks to use during training
fast_dev_run: Run a fixed number of batches (True means 1) through train/val/test for debugging
max_epochs: Maximum number of epochs to train
min_epochs: Minimum number of epochs to train
max_steps: Maximum number of training steps
min_steps: Minimum number of training steps
max_time: Maximum training time
limit_train_batches: Limit training batches per epoch
limit_val_batches: Limit validation batches per epoch
limit_test_batches: Limit test batches per run
limit_predict_batches: Limit prediction batches per run
overfit_batches: Overfit on this fraction or number of training batches (debugging aid)
val_check_interval: How often within a training epoch to run validation
check_val_every_n_epoch: Run validation every n training epochs
num_sanity_val_steps: Validation batches to run as a sanity check before training
log_every_n_steps: How often to log within training steps
enable_checkpointing: Enable automatic checkpointing
enable_progress_bar: Show progress bar during training
enable_model_summary: Print a model summary before training starts
accumulate_grad_batches: Number of batches to accumulate gradients over
gradient_clip_val: Gradient clipping value
gradient_clip_algorithm: Gradient clipping algorithm ('norm' or 'value')
deterministic: Make training deterministic
benchmark: Enable cudnn benchmarking
inference_mode: Use torch.inference_mode rather than no_grad during evaluation
use_distributed_sampler: Automatically wrap dataloaders with DistributedSampler
profiler: Profiler for performance analysis
detect_anomaly: Enable autograd anomaly detection
barebones: Disable non-essential features to benchmark raw training speed
plugins: Plugins for precision, cluster environment, etc.
sync_batchnorm: Synchronize batch norm layers across devices
reload_dataloaders_every_n_epochs: Reload dataloaders every n epochs
default_root_dir: Default directory for logs and checkpoints
"""
def fit(
self,
model: LightningModule,
train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
ckpt_path: Optional[str] = None
) -> None:
"""
Fit the model with training and validation data.
Args:
model: LightningModule to train
train_dataloaders: Training data loaders
val_dataloaders: Validation data loaders
datamodule: LightningDataModule containing data loaders
ckpt_path: Path to checkpoint to resume from
"""
def validate(
self,
model: Optional[LightningModule] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None
) -> List[Dict[str, float]]:
"""
Run validation on the model.
Args:
model: LightningModule to validate
dataloaders: Validation data loaders
ckpt_path: Path to checkpoint to load
verbose: Print validation results
datamodule: LightningDataModule containing data loaders
Returns:
List of validation metrics dictionaries
"""
def test(
self,
model: Optional[LightningModule] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None
) -> List[Dict[str, float]]:
"""
Run testing on the model.
Args:
model: LightningModule to test
dataloaders: Test data loaders
ckpt_path: Path to checkpoint to load
verbose: Print test results
datamodule: LightningDataModule containing data loaders
Returns:
List of test metrics dictionaries
"""
def predict(
self,
model: Optional[LightningModule] = None,
dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = None
) -> Optional[List[Any]]:
"""
Run prediction on the model.
Args:
model: LightningModule to use for prediction
dataloaders: Prediction data loaders
datamodule: LightningDataModule containing data loaders
return_predictions: Whether to return predictions
ckpt_path: Path to checkpoint to load
Returns:
List of predictions if return_predictions=True
"""
def tune(
self,
model: LightningModule,
train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Tune hyperparameters for the model.
Args:
model: LightningModule to tune
train_dataloaders: Training data loaders
val_dataloaders: Validation data loaders
datamodule: LightningDataModule containing data loaders
scale_batch_size_kwargs: Arguments for batch size scaling
lr_find_kwargs: Arguments for learning rate finding
Returns:
Dictionary with tuning results
"""Base class for organizing PyTorch code in Lightning. Defines model architecture, training logic, optimization, and provides hooks for the training lifecycle.
class LightningModule(nn.Module):
def __init__(self):
"""Initialize the LightningModule."""
super().__init__()
def forward(self, *args, **kwargs) -> Any:
"""
Define the forward pass of the model.
Args:
*args: Positional arguments
**kwargs: Keyword arguments
Returns:
Model output
"""
def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
"""
Define a single training step.
Args:
batch: Batch of training data
batch_idx: Index of the current batch
Returns:
Loss tensor or dictionary with 'loss' key
"""
def validation_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:
"""
Define a single validation step.
Args:
batch: Batch of validation data
batch_idx: Index of the current batch
Returns:
Optional loss tensor or metrics dictionary
"""
def test_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:
"""
Define a single test step.
Args:
batch: Batch of test data
batch_idx: Index of the current batch
Returns:
Optional loss tensor or metrics dictionary
"""
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
"""
Define a single prediction step.
Args:
batch: Batch of prediction data
batch_idx: Index of the current batch
dataloader_idx: Index of the dataloader
Returns:
Model predictions
"""
def configure_optimizers(self) -> Union[Optimizer, Dict[str, Any]]:
"""
Configure optimizers and learning rate schedulers.
Returns:
Optimizer or dictionary with optimizer/scheduler configuration
"""
def configure_callbacks(self) -> Union[List[Callback], Callback]:
"""
Configure callbacks for this model.
Returns:
List of callbacks or single callback
"""
def log(
self,
name: str,
value: Any,
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: str = "mean",
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
rank_zero_only: bool = False
) -> None:
"""
Log a key-value pair.
Args:
name: Name of the metric
value: Value to log
prog_bar: Show in progress bar
logger: Send to logger
on_step: Log at each step
on_epoch: Log at each epoch
reduce_fx: Reduction function for distributed training
sync_dist: Synchronize across distributed processes
batch_size: Current batch size for proper reduction
"""
def log_dict(
self,
dictionary: Dict[str, Any],
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
on_epoch: Optional[bool] = None,
reduce_fx: str = "mean",
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
rank_zero_only: bool = False
) -> None:
"""
Log a dictionary of key-value pairs.
Args:
dictionary: Dictionary of metrics to log
prog_bar: Show in progress bar
logger: Send to logger
on_step: Log at each step
on_epoch: Log at each epoch
reduce_fx: Reduction function for distributed training
sync_dist: Synchronize across distributed processes
batch_size: Current batch size for proper reduction
"""Encapsulates data loading logic including data downloading, preparation, splitting, and data loader creation. Provides a clean interface for data handling across train/val/test splits.
class LightningDataModule:
def __init__(self):
"""Initialize the LightningDataModule."""
def prepare_data(self) -> None:
"""
Download and prepare data. Called only on rank 0.
Use this for downloads and preprocessing that should not be repeated on every device.
"""
def setup(self, stage: str) -> None:
"""
Set up datasets for each stage.
Args:
stage: 'fit', 'validate', 'test', or 'predict'
"""
def train_dataloader(self) -> TRAIN_DATALOADERS:
"""
Create training data loader.
Returns:
Training data loader(s)
"""
def val_dataloader(self) -> EVAL_DATALOADERS:
"""
Create validation data loader.
Returns:
Validation data loader(s)
"""
def test_dataloader(self) -> EVAL_DATALOADERS:
"""
Create test data loader.
Returns:
Test data loader(s)
"""
def predict_dataloader(self) -> EVAL_DATALOADERS:
"""
Create prediction data loader.
Returns:
Prediction data loader(s)
"""
def teardown(self, stage: str) -> None:
"""
Clean up after training/testing.
Args:
stage: 'fit', 'validate', 'test', or 'predict'
"""
def state_dict(self) -> Dict[str, Any]:
"""
Called when saving a checkpoint.
Returns:
Dictionary of state to save
"""
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Called when loading a checkpoint.
Args:
state_dict: Dictionary of saved state
"""Base class for creating custom callbacks to hook into the training lifecycle. Callbacks provide a way to add functionality at specific points during training.
class Callback:
def __init__(self):
"""Initialize the callback."""
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when training begins."""
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when training ends."""
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when validation begins."""
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when validation ends."""
def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when testing begins."""
def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when testing ends."""
def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when a training epoch begins."""
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when a training epoch ends."""
def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when a validation epoch begins."""
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when a validation epoch ends."""
def on_train_batch_start(
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
) -> None:
"""Called when a training batch begins."""
def on_train_batch_end(
self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
"""Called when a training batch ends."""
def on_validation_batch_start(
self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Called when a validation batch begins."""
def on_validation_batch_end(
self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Called when a validation batch ends."""
def on_before_optimizer_step(
self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
) -> None:
"""Called before optimizer step."""
def on_before_zero_grad(
self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
) -> None:
"""Called before gradients are zeroed."""
def on_save_checkpoint(
self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]
) -> None:
"""Called when saving a checkpoint."""
def on_load_checkpoint(
self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]
) -> None:
"""Called when loading a checkpoint."""
def state_dict(self) -> Dict[str, Any]:
"""
Called when saving a checkpoint.
Returns:
Dictionary of callback state to save
"""
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Called when loading a checkpoint.
Args:
state_dict: Dictionary of saved callback state
"""Install with Tessl CLI
npx tessl i tessl/pypi-lightning