Lightning: the deep learning framework to train, deploy, and ship AI products — Lightning fast.

Comprehensive callback system for training lifecycle management, including checkpointing, early stopping, learning-rate scheduling, monitoring, and optimization callbacks. Callbacks provide a clean way to add functionality without modifying the core training loop.
Automatically save model checkpoints during training based on monitored metrics, with support for saving top-k models and automatic cleanup.
class ModelCheckpoint(Callback):
    def __init__(
        self,
        dirpath: Optional[str] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        verbose: bool = False,
        save_last: Optional[bool] = None,
        save_top_k: int = 1,
        save_weights_only: bool = False,
        mode: str = "min",
        auto_insert_metric_name: bool = True,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        every_n_epochs: Optional[int] = None,
        save_on_train_epoch_end: Optional[bool] = None,
        enable_version_counter: bool = True,
    ):
        """
        Initialize ModelCheckpoint callback.

        Args:
            dirpath: Directory to save checkpoints
            filename: Checkpoint filename pattern
            monitor: Metric to monitor for saving best models
            verbose: Print checkpoint saving messages
            save_last: Always save the last checkpoint
            save_top_k: Number of best models to save
            save_weights_only: Save only model weights
            mode: 'min' or 'max' for monitored metric
            auto_insert_metric_name: Insert metric name in filename
            every_n_train_steps: Save every N training steps
            train_time_interval: Save every time interval
            every_n_epochs: Save every N epochs
            save_on_train_epoch_end: Save at end of training epoch
            enable_version_counter: Enable version counter in filename
        """

    @property
    def best_model_path(self) -> str:
        """Path to the best saved model."""

    @property
    def best_model_score(self) -> Optional[float]:
        """Score of the best saved model."""

    @property
    def last_model_path(self) -> str:
        """Path to the last saved model."""


# Stop training when a monitored metric stops improving, with configurable
# patience and thresholds to prevent overfitting.
class EarlyStopping(Callback):
    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = "min",
        strict: bool = True,
        check_finite: bool = True,
        stopping_threshold: Optional[float] = None,
        divergence_threshold: Optional[float] = None,
        check_on_train_epoch_end: Optional[bool] = None,
        log_rank_zero_only: bool = False,
    ):
        """
        Initialize EarlyStopping callback.

        Args:
            monitor: Metric to monitor
            min_delta: Minimum change to qualify as improvement
            patience: Number of epochs with no improvement to wait
            verbose: Print early stopping messages
            mode: 'min' or 'max' for monitored metric
            strict: Raise error if monitored metric is not found
            check_finite: Stop if monitored metric is not finite
            stopping_threshold: Stop when metric reaches this threshold
            divergence_threshold: Stop if metric diverges beyond this
            check_on_train_epoch_end: Check metric at end of training epoch
            log_rank_zero_only: Log only on rank 0
        """

    @property
    def wait_count(self) -> int:
        """Number of epochs waited since last improvement."""

    @property
    def best_score(self) -> Optional[float]:
        """Best score achieved."""

    @property
    def stopped_epoch(self) -> int:
        """Epoch when training was stopped."""


# Monitor and log learning rate changes during training, supporting multiple
# optimizers and schedulers.
class LearningRateMonitor(Callback):
    def __init__(
        self,
        logging_interval: str = "epoch",
        log_momentum: bool = False,
        log_weight_decay: bool = False,
    ):
        """
        Initialize LearningRateMonitor callback.

        Args:
            logging_interval: 'step' or 'epoch' for logging frequency
            log_momentum: Also log momentum values
            log_weight_decay: Also log weight decay values
        """


# Implement stochastic weight averaging to improve model generalization by
# averaging weights from multiple epochs.
class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_lrs: Union[float, List[float]],
        swa_epoch_start: Union[int, float] = 0.8,
        annealing_epochs: int = 10,
        annealing_strategy: str = "cos",
        avg_fn: Optional[Callable] = None,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """
        Initialize StochasticWeightAveraging callback.

        Args:
            swa_lrs: Learning rate(s) for SWA
            swa_epoch_start: Epoch to start SWA (int or fraction)
            annealing_epochs: Number of epochs for annealing
            annealing_strategy: 'linear' or 'cos' annealing
            avg_fn: Custom averaging function
            device: Device for SWA model
        """


# Visual progress indicators during training with customizable display options
# and rich formatting support.
class TQDMProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        process_position: int = 0,
    ):
        """
        Initialize TQDM progress bar.

        Args:
            refresh_rate: Progress bar refresh rate
            process_position: Position for multiple progress bars
        """
class RichProgressBar(Callback):
    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: RichProgressBarTheme = RichProgressBarTheme(),
        console_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize Rich progress bar with enhanced formatting.

        Args:
            refresh_rate: Progress bar refresh rate
            leave: Keep progress bar after completion
            theme: Rich theme configuration
            console_kwargs: Additional console arguments
        """
class ProgressBar(Callback):
    def __init__(self):
        """Base progress bar callback."""

    def disable(self) -> None:
        """Disable the progress bar."""

    def enable(self) -> None:
        """Enable the progress bar."""


# Display detailed model architecture information including layer types,
# parameters, and memory usage.
class ModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize ModelSummary callback.

        Args:
            max_depth: Maximum depth for nested modules
        """
class RichModelSummary(Callback):
    def __init__(self, max_depth: int = 1):
        """
        Initialize RichModelSummary with enhanced formatting.

        Args:
            max_depth: Maximum depth for nested modules
        """


# Callbacks for automated hyperparameter tuning including batch size finding
# and learning rate finding.
class BatchSizeFinder(Callback):
    def __init__(
        self,
        mode: str = "power",
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = "batch_size",
    ):
        """
        Initialize BatchSizeFinder callback.

        Args:
            mode: 'power' or 'binsearch' for search strategy
            steps_per_trial: Steps per batch size trial
            init_val: Initial batch size
            max_trials: Maximum number of trials
            batch_arg_name: Argument name for batch size
        """
class LearningRateFinder(Callback):
    def __init__(
        self,
        min_lr: float = 1e-8,
        max_lr: float = 1.0,
        num_training: int = 100,
        mode: str = "exponential",
        early_stop_threshold: float = 4.0,
        update_attr: bool = False,
    ):
        """
        Initialize LearningRateFinder callback.

        Args:
            min_lr: Minimum learning rate
            max_lr: Maximum learning rate
            num_training: Number of training steps
            mode: 'exponential' or 'linear' search
            early_stop_threshold: Threshold for early stopping
            update_attr: Update model's learning rate attribute
        """


# Specialized callbacks for transfer learning and progressive fine-tuning
# strategies.
class BaseFinetuning(Callback):
    def __init__(self, unfreeze_at_epoch: int = 10, lambda_func: Optional[Callable] = None):
        """
        Base class for fine-tuning callbacks.

        Args:
            unfreeze_at_epoch: Epoch to unfreeze parameters
            lambda_func: Function to determine learning rates
        """

    def freeze_before_training(self, pl_module: LightningModule) -> None:
        """Freeze parameters before training starts."""

    def finetune_function(
        self,
        pl_module: LightningModule,
        current_epoch: int,
        optimizer: Optimizer,
        optimizer_idx: int,
    ) -> None:
        """Function called during fine-tuning."""
class BackboneFinetuning(BaseFinetuning):
    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Optional[Callable] = None,
        backbone_initial_ratio_lr: float = 0.1,
        backbone_initial_lr: Optional[float] = None,
        should_align: bool = True,
        initial_denom_lr: float = 10.0,
        train_bn: bool = True,
    ):
        """
        Fine-tuning callback for backbone networks.

        Args:
            unfreeze_backbone_at_epoch: Epoch to unfreeze backbone
            lambda_func: Learning rate scheduling function
            backbone_initial_ratio_lr: Initial backbone LR ratio
            backbone_initial_lr: Initial backbone learning rate
            should_align: Align learning rates
            initial_denom_lr: Initial denominator for LR calculation
            train_bn: Train batch normalization layers
        """


# Callbacks for monitoring training performance, throughput, and resource
# utilization.
class ThroughputMonitor(Callback):
    def __init__(
        self,
        length_key: str = "seq_len",
        batch_size_key: str = "batch_size",
        window_size: int = 100,
    ):
        """
        Initialize ThroughputMonitor callback.

        Args:
            length_key: Key for sequence length in batch
            batch_size_key: Key for batch size
            window_size: Window size for throughput calculation
        """
class DeviceStatsMonitor(Callback):
    def __init__(self, cpu_stats: Optional[bool] = None):
        """
        Initialize DeviceStatsMonitor callback.

        Args:
            cpu_stats: Monitor CPU statistics
        """
class Timer(Callback):
    def __init__(self, duration: Optional[Union[str, timedelta]] = None, interval: str = "step"):
        """
        Initialize Timer callback for training duration control.

        Args:
            duration: Maximum training duration
            interval: 'step' or 'epoch' for timing
        """


class LambdaCallback(Callback):
    def __init__(self, **kwargs):
        """
        Create callback from lambda functions.

        Args:
            **kwargs: Mapping of hook names to functions
        """


# Usage example: configuring callbacks and passing them to the Trainer.
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Configure callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min',
)

# Use callbacks in trainer
trainer = Trainer(
    callbacks=[checkpoint_callback, early_stopping],
    max_epochs=100,
)

import lightning as L
class MetricLoggingCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        """Log custom metrics at the end of each training epoch.

        Args:
            trainer: The Trainer running the loop; supplies metrics and epoch.
            pl_module: The LightningModule being trained (unused here).
        """
        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch
        # Custom logging logic
        if 'train_loss' in metrics:
            print(f"Epoch {epoch}: Train Loss = {metrics['train_loss']:.4f}")
        # Save metrics to file
        with open('metrics.log', 'a') as f:
            f.write(f"Epoch {epoch}: {dict(metrics)}\n")


# Use custom callback
trainer = Trainer(callbacks=[MetricLoggingCallback()])

# Install with Tessl CLI:
npx tessl i tessl/pypi-lightning