CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-lightning

The Deep Learning framework to train, deploy, and ship AI products Lightning fast.

Pending
Overview
Eval results
Files

docs/callbacks.md

Callbacks and Lifecycle Hooks

Comprehensive callback system for training lifecycle management including checkpointing, early stopping, learning rate scheduling, monitoring, and optimization callbacks. Callbacks provide a clean way to add functionality without modifying the core training loop.

Capabilities

Model Checkpointing

Automatically save model checkpoints during training based on monitored metrics, with support for saving top-k models and automatic cleanup.

class ModelCheckpoint(Callback):
    def __init__(
        self,
        dirpath: Optional[str] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        verbose: bool = False,
        save_last: Optional[bool] = None,
        save_top_k: int = 1,
        save_weights_only: bool = False,
        mode: str = "min",
        auto_insert_metric_name: bool = True,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        every_n_epochs: Optional[int] = None,
        save_on_train_epoch_end: Optional[bool] = None,
        enable_version_counter: bool = True
    ):
        """
        Initialize ModelCheckpoint callback.
        
        Args:
            dirpath: Directory to save checkpoints
            filename: Checkpoint filename pattern
            monitor: Metric to monitor for saving best models
            verbose: Print checkpoint saving messages
            save_last: Always save the last checkpoint
            save_top_k: Number of best models to save
            save_weights_only: Save only model weights
            mode: 'min' or 'max' for monitored metric
            auto_insert_metric_name: Insert metric name in filename
            every_n_train_steps: Save every N training steps
            train_time_interval: Save every time interval
            every_n_epochs: Save every N epochs
            save_on_train_epoch_end: Save at end of training epoch
            enable_version_counter: Enable version counter in filename
        """

    @property
    def best_model_path(self) -> str:
        """Path to the best saved model."""

    @property
    def best_model_score(self) -> Optional[float]:
        """Score of the best saved model."""

    @property
    def last_model_path(self) -> str:
        """Path to the last saved model."""

Early Stopping

Stop training when a monitored metric stops improving, with configurable patience and thresholds to prevent overfitting.

class EarlyStopping(Callback):
    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
        check_on_train_epoch_end: Optional[bool] = None,
        log_rank_zero_only: bool = False
    ):
        """
        Initialize EarlyStopping callback.
        
        Args:
            monitor: Metric to monitor
            min_delta: Minimum change to qualify as improvement
            patience: Number of epochs with no improvement to wait
            verbose: Print early stopping messages
            mode: 'min' or 'max' for monitored metric
            strict: Raise error if monitored metric is not found
            check_finite: Stop if monitored metric is not finite
            stopping_threshold: Stop when metric reaches this threshold
            divergence_threshold: Stop if metric diverges beyond this
            check_on_train_epoch_end: Check metric at end of training epoch
            log_rank_zero_only: Log only on rank 0
        """

    @property
    def wait_count(self) -> int:
        """Number of epochs waited since last improvement."""

    @property
    def best_score(self) -> Optional[float]:
        """Best score achieved."""

    @property
    def stopped_epoch(self) -> int:
        """Epoch when training was stopped."""

Learning Rate Monitoring

Monitor and log learning rate changes during training, supporting multiple optimizers and schedulers.

class LearningRateMonitor(Callback):
    def __init__(
        self,
        logging_interval: str = "epoch",
        log_momentum: bool = False,
        log_weight_decay: bool = False
    ):
        """
        Initialize LearningRateMonitor callback.
        
        Args:
            logging_interval: 'step' or 'epoch' for logging frequency  
            log_momentum: Also log momentum values
            log_weight_decay: Also log weight decay values
        """

Stochastic Weight Averaging

Implement stochastic weight averaging to improve model generalization by averaging weights from multiple epochs.

class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_lrs: Union[float, List[float]],
        swa_epoch_start: Union[int, float] = 0.8,
        annealing_epochs: int = 10,
        annealing_strategy: str = "cos",
        avg_fn: Optional[Callable] = None,
        device: Optional[Union[torch.device, str]] = None
    ):
        """
        Initialize StochasticWeightAveraging callback.
        
        Args:
            swa_lrs: Learning rate(s) for SWA
            swa_epoch_start: Epoch to start SWA (int or fraction)
            annealing_epochs: Number of epochs for annealing
            annealing_strategy: 'linear' or 'cos' annealing
            avg_fn: Custom averaging function
            device: Device for SWA model
        """

Progress Bars

Visual progress indicators during training with customizable display options and rich formatting support.

class TQDMProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        process_position: int = 0
    ):
        """
        Initialize TQDM progress bar.
        
        Args:
            refresh_rate: Progress bar refresh rate
            process_position: Position for multiple progress bars
        """

class RichProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: RichProgressBarTheme = RichProgressBarTheme(),
        console_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize Rich progress bar with enhanced formatting.
        
        Args:
            refresh_rate: Progress bar refresh rate
            leave: Keep progress bar after completion
            theme: Rich theme configuration
            console_kwargs: Additional console arguments
        """

class ProgressBar(Callback):
    def __init__(self):
        """Base progress bar callback."""

    def disable(self) -> None:
        """Disable the progress bar."""

    def enable(self) -> None:
        """Enable the progress bar."""

Model Summary Display

Display detailed model architecture information including layer types, parameters, and memory usage.

class ModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize ModelSummary callback.
        
        Args:
            max_depth: Maximum depth for nested modules
        """

class RichModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize RichModelSummary with enhanced formatting.
        
        Args:
            max_depth: Maximum depth for nested modules
        """

Hyperparameter Optimization

Callbacks for automated hyperparameter tuning including batch size finding and learning rate finding.

class BatchSizeFinder(Callback):
    def __init__(
        self,
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size"
    ):
        """
        Initialize BatchSizeFinder callback.
        
        Args:
            mode: 'power' or 'binsearch' for search strategy
            steps_per_trial: Steps per batch size trial
            init_val: Initial batch size
            max_trials: Maximum number of trials
            batch_arg_name: Argument name for batch size
        """

class LearningRateFinder(Callback):
    def __init__(
        self,
        min_lr: float = 1e-8,
        max_lr: float = 1.0,
        num_training: int = 100,
        mode: str = "exponential",
        early_stop_threshold: float = 4.0,
        update_attr: bool = False
    ):
        """
        Initialize LearningRateFinder callback.
        
        Args:
            min_lr: Minimum learning rate
            max_lr: Maximum learning rate  
            num_training: Number of training steps
            mode: 'exponential' or 'linear' search
            early_stop_threshold: Threshold for early stopping  
            update_attr: Update model's learning rate attribute
        """

Fine-tuning Callbacks

Specialized callbacks for transfer learning and progressive fine-tuning strategies.

class BaseFinetuning(Callback):
    def __init__(self, unfreeze_at_epoch: int = 10, lambda_func: Optional[Callable] = None):
        """
        Base class for fine-tuning callbacks.
        
        Args:
            unfreeze_at_epoch: Epoch to unfreeze parameters
            lambda_func: Function to determine learning rates
        """

    def freeze_before_training(self, pl_module: LightningModule) -> None:
        """Freeze parameters before training starts."""

    def finetune_function(
        self,
        pl_module: LightningModule,
        current_epoch: int,
        optimizer: Optimizer,
        optimizer_idx: int
    ) -> None:
        """Function called during fine-tuning."""

class BackboneFinetuning(BaseFinetuning):
    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Optional[Callable] = None,
        backbone_initial_ratio_lr: float = 0.1,
        backbone_initial_lr: Optional[float] = None,
        should_align: bool = True,
        initial_denom_lr: float = 10.0,
        train_bn: bool = True
    ):
        """
        Fine-tuning callback for backbone networks.
        
        Args:
            unfreeze_backbone_at_epoch: Epoch to unfreeze backbone
            lambda_func: Learning rate scheduling function
            backbone_initial_ratio_lr: Initial backbone LR ratio
            backbone_initial_lr: Initial backbone learning rate
            should_align: Align learning rates
            initial_denom_lr: Initial denominator for LR calculation
            train_bn: Train batch normalization layers
        """

Performance Monitoring

Callbacks for monitoring training performance, throughput, and resource utilization.

class ThroughputMonitor(Callback):
    def __init__(
        self,
        length_key: str = "seq_len",
        batch_size_key: str = "batch_size",
        window_size: int = 100
    ):
        """
        Initialize ThroughputMonitor callback.
        
        Args:
            length_key: Key for sequence length in batch
            batch_size_key: Key for batch size
            window_size: Window size for throughput calculation
        """

class DeviceStatsMonitor(Callback):
    def __init__(self, cpu_stats: Optional[bool] = None):
        """
        Initialize DeviceStatsMonitor callback.
        
        Args:
            cpu_stats: Monitor CPU statistics
        """

class Timer(Callback):
    def __init__(self, duration: Optional[Union[str, timedelta]] = None, interval: str = "step"):
        """
        Initialize Timer callback for training duration control.
        
        Args:
            duration: Maximum training duration
            interval: 'step' or 'epoch' for timing
        """

Custom Callback Creation

class LambdaCallback(Callback):
    def __init__(self, **kwargs):
        """
        Create callback from lambda functions.
        
        Args:
            **kwargs: Mapping of hook names to functions
        """

Usage Examples

Basic Callback Usage

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Configure callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min'
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)

# Use callbacks in trainer
trainer = Trainer(
    callbacks=[checkpoint_callback, early_stopping],
    max_epochs=100
)

Custom Callback Example

import lightning as L

class MetricLoggingCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Log custom metrics at end of each epoch
        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch
        
        # Custom logging logic
        if 'train_loss' in metrics:
            print(f"Epoch {epoch}: Train Loss = {metrics['train_loss']:.4f}")
        
        # Save metrics to file
        with open('metrics.log', 'a') as f:
            f.write(f"Epoch {epoch}: {dict(metrics)}\n")

# Use custom callback (Trainer is available as L.Trainer with this import style)
trainer = L.Trainer(callbacks=[MetricLoggingCallback()])

Install with Tessl CLI

npx tessl i tessl/pypi-lightning

docs

accelerators.md

callbacks.md

core-training.md

data.md

fabric.md

index.md

loggers.md

precision.md

profilers.md

strategies.md

tile.json