CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-lightning

The Deep Learning framework to train, deploy, and ship AI products Lightning fast.

Pending
Overview
Eval results
Files

docs/core-training.md

Core Training Components

Essential components for structuring deep learning training workflows in Lightning. These components provide the foundation for organized, scalable, and reproducible machine learning training.

Capabilities

Trainer

The main entry point for Lightning training that orchestrates the entire training process, handling distributed training, logging, checkpointing, and validation automatically.

class Trainer:
    """Main entry point for Lightning training.

    Orchestrates the entire training process, handling distributed
    training, logging, checkpointing, and validation automatically.
    """

    def __init__(
        self,
        accelerator: str = "auto",
        strategy: str = "auto",
        devices: Union[List[int], str, int] = "auto",
        num_nodes: int = 1,
        precision: Union[str, int] = "32-true",
        logger: Union[Logger, bool] = True,
        callbacks: Optional[List[Callback]] = None,
        fast_dev_run: Union[bool, int] = False,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        max_steps: int = -1,
        min_steps: Optional[int] = None,
        # Fixed: default is None, so the annotation must be Optional
        # (it previously read Union[str, timedelta] with a None default).
        max_time: Optional[Union[str, timedelta]] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        limit_predict_batches: Union[int, float] = 1.0,
        overfit_batches: Union[int, float] = 0.0,
        val_check_interval: Union[int, float] = 1.0,
        check_val_every_n_epoch: Optional[int] = 1,
        num_sanity_val_steps: int = 2,
        log_every_n_steps: int = 50,
        enable_checkpointing: bool = True,
        enable_progress_bar: bool = True,
        enable_model_summary: bool = True,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Optional[str] = None,
        deterministic: Optional[bool] = None,
        benchmark: Optional[bool] = None,
        inference_mode: bool = True,
        use_distributed_sampler: bool = True,
        profiler: Optional[Profiler] = None,
        detect_anomaly: bool = False,
        barebones: bool = False,
        plugins: Optional[List[Any]] = None,
        sync_batchnorm: bool = False,
        reload_dataloaders_every_n_epochs: int = 0,
        default_root_dir: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the Lightning Trainer.

        Args:
            accelerator: Hardware accelerator type ('cpu', 'gpu', 'tpu', 'auto')
            strategy: Distributed training strategy ('ddp', 'fsdp', 'deepspeed', etc.)
            devices: Which devices to use for training
            num_nodes: Number of nodes for distributed training
            precision: Precision mode ('32-true', '16-mixed', 'bf16-mixed', etc.)
            logger: Logger instance or True/False to enable/disable default logger
            callbacks: List of callbacks to use during training
            fast_dev_run: Run a single batch (or the given number of batches) for debugging
            max_epochs: Maximum number of epochs to train
            min_epochs: Minimum number of epochs to train
            max_steps: Maximum number of training steps (-1 for no limit)
            min_steps: Minimum number of training steps
            max_time: Maximum training time, as a string or timedelta; None for no limit
            limit_train_batches: Limit training batches per epoch (fraction or count)
            limit_val_batches: Limit validation batches per epoch (fraction or count)
            limit_test_batches: Limit test batches per epoch (fraction or count)
            limit_predict_batches: Limit prediction batches per epoch (fraction or count)
            overfit_batches: Overfit on this fraction/number of training batches
            val_check_interval: How often to check validation within an epoch
            check_val_every_n_epoch: Run validation every N epochs (None to disable)
            num_sanity_val_steps: Validation batches to run before training starts
            log_every_n_steps: How often to log within training steps
            enable_checkpointing: Enable automatic checkpointing
            enable_progress_bar: Show progress bar during training
            enable_model_summary: Print a model summary before training
            accumulate_grad_batches: Gradient accumulation steps
            gradient_clip_val: Gradient clipping value
            gradient_clip_algorithm: Gradient clipping algorithm (e.g. 'norm', 'value')
            deterministic: Make training deterministic
            benchmark: Enable backend benchmarking for the selected hardware
            inference_mode: Run evaluation loops without gradient tracking
            use_distributed_sampler: Automatically wrap dataloaders for distributed runs
            profiler: Profiler for performance analysis
            detect_anomaly: Enable autograd anomaly detection (slow; debugging only)
            barebones: Disable non-essential features for minimal-overhead training
            plugins: Custom plugins (precision, cluster environment, etc.)
            sync_batchnorm: Synchronize batch-norm statistics across devices
            reload_dataloaders_every_n_epochs: Reload dataloaders every N epochs
            default_root_dir: Default directory for logs and checkpoints
            **kwargs: Additional trainer options forwarded to the implementation
        """

    def fit(
        self,
        model: LightningModule,
        train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None
    ) -> None:
        """
        Fit the model with training and validation data.

        Args:
            model: LightningModule to train
            train_dataloaders: Training data loaders
            val_dataloaders: Validation data loaders
            datamodule: LightningDataModule containing data loaders
            ckpt_path: Path to checkpoint to resume from
        """

    def validate(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[EVAL_DATALOADERS] = None,
        ckpt_path: Optional[str] = None,
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None
    ) -> List[Dict[str, float]]:
        """
        Run validation on the model.

        Args:
            model: LightningModule to validate
            dataloaders: Validation data loaders
            ckpt_path: Path to checkpoint to load
            verbose: Print validation results
            datamodule: LightningDataModule containing data loaders

        Returns:
            List of validation metrics dictionaries
        """

    def test(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[EVAL_DATALOADERS] = None,
        ckpt_path: Optional[str] = None,
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None
    ) -> List[Dict[str, float]]:
        """
        Run testing on the model.

        Args:
            model: LightningModule to test
            dataloaders: Test data loaders
            ckpt_path: Path to checkpoint to load
            verbose: Print test results
            datamodule: LightningDataModule containing data loaders

        Returns:
            List of test metrics dictionaries
        """

    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = None
    ) -> Optional[List[Any]]:
        """
        Run prediction on the model.

        Args:
            model: LightningModule to use for prediction
            dataloaders: Prediction data loaders
            datamodule: LightningDataModule containing data loaders
            return_predictions: Whether to return predictions
            ckpt_path: Path to checkpoint to load

        Returns:
            List of predictions if return_predictions=True
        """

    def tune(
        self,
        model: LightningModule,
        train_dataloaders: Optional[TRAIN_DATALOADERS] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
        lr_find_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Tune hyperparameters for the model.

        Args:
            model: LightningModule to tune
            train_dataloaders: Training data loaders
            val_dataloaders: Validation data loaders
            datamodule: LightningDataModule containing data loaders
            scale_batch_size_kwargs: Arguments for batch size scaling
            lr_find_kwargs: Arguments for learning rate finding

        Returns:
            Dictionary with tuning results
        """

LightningModule

Base class for organizing PyTorch code in Lightning. Defines model architecture, training logic, optimization, and provides hooks for the training lifecycle.

class LightningModule(nn.Module):
    """Base class for organizing PyTorch code in Lightning.

    Subclasses define the model architecture (``forward``), the per-step
    training/validation/test/predict logic, and optimizer configuration.
    """

    def __init__(self):
        """Initialize the LightningModule."""
        super().__init__()

    def forward(self, *args, **kwargs) -> Any:
        """
        Define the forward pass of the model.
        
        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments
            
        Returns:
            Model output
        """

    def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        """
        Define a single training step.
        
        Args:
            batch: Batch of training data
            batch_idx: Index of the current batch
            
        Returns:
            Loss tensor or dictionary with 'loss' key
        """

    def validation_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:
        """
        Define a single validation step.
        
        Args:
            batch: Batch of validation data
            batch_idx: Index of the current batch
            
        Returns:
            Optional loss tensor or metrics dictionary
        """

    def test_step(self, batch: Any, batch_idx: int) -> Optional[STEP_OUTPUT]:
        """
        Define a single test step.
        
        Args:
            batch: Batch of test data
            batch_idx: Index of the current batch
            
        Returns:
            Optional loss tensor or metrics dictionary
        """

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """
        Define a single prediction step.
        
        Args:
            batch: Batch of prediction data
            batch_idx: Index of the current batch
            dataloader_idx: Index of the dataloader
            
        Returns:
            Model predictions
        """

    def configure_optimizers(self) -> Union[Optimizer, Dict[str, Any]]:
        """
        Configure optimizers and learning rate schedulers.
        
        Returns:
            Optimizer or dictionary with optimizer/scheduler configuration
        """

    def configure_callbacks(self) -> Union[List[Callback], Callback]:
        """
        Configure callbacks for this model.
        
        Returns:
            List of callbacks or single callback
        """

    def log(
        self,
        name: str,
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: Optional[bool] = None,
        on_epoch: Optional[bool] = None,
        reduce_fx: str = "mean",
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_group: Optional[Any] = None,
        add_dataloader_idx: bool = True,
        batch_size: Optional[int] = None,
        metric_attribute: Optional[str] = None,
        rank_zero_only: bool = False
    ) -> None:
        """
        Log a key-value pair.
        
        Args:
            name: Name of the metric
            value: Value to log
            prog_bar: Show in progress bar
            logger: Send to logger
            on_step: Log at each step
            on_epoch: Log at each epoch
            reduce_fx: Reduction function for distributed training
            enable_graph: Keep the autograd graph attached to the logged value
            sync_dist: Synchronize across distributed processes
            sync_dist_group: Process group to use when syncing across processes
            add_dataloader_idx: Append the dataloader index to the metric name
            batch_size: Current batch size for proper reduction
            metric_attribute: Attribute name of the metric on the module
            rank_zero_only: Log only from the rank-0 process
        """

    def log_dict(
        self,
        dictionary: Dict[str, Any],
        prog_bar: bool = False,
        logger: bool = True,
        on_step: Optional[bool] = None,
        on_epoch: Optional[bool] = None,
        reduce_fx: str = "mean",
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_group: Optional[Any] = None,
        add_dataloader_idx: bool = True,
        batch_size: Optional[int] = None,
        rank_zero_only: bool = False
    ) -> None:
        """
        Log a dictionary of key-value pairs.
        
        Args:
            dictionary: Dictionary of metrics to log
            prog_bar: Show in progress bar
            logger: Send to logger
            on_step: Log at each step
            on_epoch: Log at each epoch
            reduce_fx: Reduction function for distributed training
            enable_graph: Keep the autograd graph attached to logged values
            sync_dist: Synchronize across distributed processes
            sync_dist_group: Process group to use when syncing across processes
            add_dataloader_idx: Append the dataloader index to metric names
            batch_size: Current batch size for proper reduction
            rank_zero_only: Log only from the rank-0 process
        """

LightningDataModule

Encapsulates data loading logic including data downloading, preparation, splitting, and data loader creation. Provides a clean interface for data handling across train/val/test splits.

class LightningDataModule:
    """Encapsulates data loading logic for Lightning.

    Groups data download/preparation, dataset setup per stage, and
    train/val/test/predict dataloader creation behind one interface.
    """

    def __init__(self):
        """Initialize the LightningDataModule."""

    def prepare_data(self) -> None:
        """
        Download and prepare data. Called only on rank 0.
        Use this for data download, preprocessing that shouldn't be done on every device.
        """

    def setup(self, stage: str) -> None:
        """
        Set up datasets for each stage.
        
        Args:
            stage: 'fit', 'validate', 'test', or 'predict'
        """

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """
        Create training data loader.
        
        Returns:
            Training data loader(s)
        """

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """
        Create validation data loader.
        
        Returns:
            Validation data loader(s)
        """

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """
        Create test data loader.
        
        Returns:
            Test data loader(s)
        """

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """
        Create prediction data loader.
        
        Returns:
            Prediction data loader(s)
        """

    def teardown(self, stage: str) -> None:
        """
        Clean up after training/testing.
        
        Args:
            stage: 'fit', 'validate', 'test', or 'predict'
        """

    def state_dict(self) -> Dict[str, Any]:
        """
        Called when saving a checkpoint.
        
        Returns:
            Dictionary of state to save
        """

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Called when loading a checkpoint.
        
        Args:
            state_dict: Dictionary of saved state
        """

Callback

Base class for creating custom callbacks to hook into the training lifecycle. Callbacks provide a way to add functionality at specific points during training.

class Callback:
    """Base class for custom callbacks that hook into the training lifecycle.

    Every hook receives the active ``Trainer`` and the ``LightningModule``
    being trained; override only the hooks you need.
    """

    def __init__(self):
        """Initialize the callback."""

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when training begins."""

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when training ends."""

    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when validation begins."""

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when validation ends."""

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when testing begins."""

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when testing ends."""

    # NOTE(review): on_epoch_start/on_epoch_end were removed in Lightning 2.x
    # in favor of the stage-specific hooks below — confirm against the target
    # Lightning version before relying on them.
    def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when an epoch begins."""

    def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when an epoch ends."""

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when a training epoch begins."""

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when a training epoch ends."""

    def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when a validation epoch begins."""

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Called when a validation epoch ends."""

    def on_train_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
    ) -> None:
        """Called when a training batch begins."""

    def on_train_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Called when a training batch ends, with the training_step outputs."""

    def on_validation_batch_start(
        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Called when a validation batch begins."""

    def on_validation_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Called when a validation batch ends, with the validation_step outputs."""

    # NOTE(review): Lightning 2.x dropped the opt_idx argument from this hook
    # — verify the signature against the target Lightning version.
    def on_before_optimizer_step(
        self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer, opt_idx: int
    ) -> None:
        """Called before optimizer step."""

    def on_before_zero_grad(
        self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer
    ) -> None:
        """Called before gradients are zeroed."""

    def on_save_checkpoint(
        self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]
    ) -> None:
        """Called when saving a checkpoint."""

    def on_load_checkpoint(
        self, trainer: Trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]
    ) -> None:
        """Called when loading a checkpoint."""

    def state_dict(self) -> Dict[str, Any]:
        """
        Called when saving a checkpoint.
        
        Returns:
            Dictionary of callback state to save
        """

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Called when loading a checkpoint.
        
        Args:
            state_dict: Dictionary of saved callback state
        """

Install with Tessl CLI

npx tessl i tessl/pypi-lightning

docs

accelerators.md

callbacks.md

core-training.md

data.md

fabric.md

index.md

loggers.md

precision.md

profilers.md

strategies.md

tile.json