CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-fastai

fastai simplifies training fast and accurate neural nets using modern best practices

Pending
Overview
Eval results
Files

docs/callbacks.md

Callbacks and Training Customization

An extensive callback system for customizing the training loop, including progress tracking, learning rate scheduling, regularization, logging, and advanced training techniques.

Capabilities

Core Callback Infrastructure

Base classes and essential callbacks that form the foundation of fastai's training system.

class Callback:
    """
    Base class for training callbacks.

    Callbacks hook into different points of the training loop by overriding
    the event methods below. Every event method defined here is a no-op that
    returns None, so a subclass only implements the events it cares about.
    """

    def __init__(self): ...

    def before_fit(self):
        """Called once before fitting starts."""

    def before_epoch(self):
        """Called before each epoch."""

    def before_train(self):
        """Called before the training phase of each epoch."""

    def before_batch(self):
        """Called before each batch is processed."""

    def after_pred(self):
        """Called after the model has produced predictions for a batch."""

    def after_loss(self):
        """Called after the loss has been computed for a batch."""

    def before_backward(self):
        """Called before the backward pass."""

    def after_backward(self):
        """Called after the backward pass."""

    def after_step(self):
        """Called after the optimizer step."""

    def after_cancel_batch(self):
        """Called if the current batch is cancelled."""

    def after_batch(self):
        """Called after each batch."""

    def after_cancel_train(self):
        """Called if the training phase of an epoch is cancelled."""

    def after_train(self):
        """Called after the training phase of each epoch."""

    def before_validate(self):
        """Called before the validation phase."""

    def after_cancel_validate(self):
        """Called if the validation phase is cancelled."""

    def after_validate(self):
        """Called after the validation phase."""

    def after_cancel_epoch(self):
        """Called if the current epoch is cancelled."""

    def after_epoch(self):
        """Called after each epoch."""

    def after_cancel_fit(self):
        """Called if the whole fit (all remaining epochs) is cancelled."""

    def after_fit(self):
        """Called once after fitting completes."""

class TrainEvalCallback(Callback):
    """
    Toggle the model between training and evaluation modes.

    Hooks the start of the fit, the training phase and the validation phase.
    """

    def before_fit(self):
        pass

    def before_train(self):
        pass

    def before_validate(self):
        pass

class Recorder(Callback):
    """Keep a record of training statistics and metrics as training runs."""

    def before_fit(self):
        pass

    def after_batch(self):
        pass

    def after_epoch(self):
        pass

    def plot_loss(self, skip_start=5, with_valid=True):
        """
        Plot the recorded losses.

        Parameters:
        - skip_start: Number of leading points to skip (default 5)
        - with_valid: Whether to include validation losses
        """

    def plot_sched(self, keys=None, figsize=None):
        """
        Plot the recorded schedules.

        Parameters:
        - keys: Schedule keys to plot
        - figsize: Figure size
        """

Learning Rate Scheduling

Callbacks for sophisticated learning rate scheduling and optimization.

class OneCycleTraining(Callback):
    """
    One-cycle learning rate policy for super-convergence.

    The learning rate rises from a low value up to ``max_lr`` and then
    anneals back down to a low value over a single cycle.
    """

    def __init__(self, max_lr=None, div_factor=25.0, final_div=None,
                 pct_start=0.25, anneal_strategy='cos', cycle_momentum=True,
                 base_momentum=0.85, max_momentum=0.95, wd=None,
                 moms=None, **kwargs):
        """
        Initialize one cycle training.

        Parameters:
        - max_lr: Maximum (peak) learning rate
        - div_factor: Initial LR divisor; the cycle starts at max_lr/div_factor
        - final_div: Final LR divisor
        - pct_start: Fraction of the cycle (not a percentage: 0.25 = 25%)
          spent warming up towards max_lr
        - anneal_strategy: 'cos' or 'linear' annealing
        - cycle_momentum: Cycle momentum inversely to the learning rate
        - base_momentum: Minimum momentum value
        - max_momentum: Maximum momentum value
        - wd: Weight decay
        - moms: Custom momentum schedule
        - **kwargs: Additional keyword arguments

        NOTE(review): how ``moms`` interacts with base_momentum/max_momentum
        is not specified here; confirm which one takes precedence.
        """

class ReduceLROnPlateau(Callback):
    """Lower the learning rate once the monitored metric stops improving."""

    def __init__(
        self,
        monitor='valid_loss',
        comp=None,
        min_delta=0,
        patience=1,
        factor=0.2,
        min_lr=0,
        reset_on_fit=True,
    ):
        """
        Set up plateau-based learning rate reduction.

        Parameters:
        - monitor: Name of the metric to watch
        - comp: Comparison function deciding what counts as better
          (np.less for a loss, np.greater for an accuracy)
        - min_delta: Smallest change that still counts as an improvement
        - patience: Number of epochs to wait before reducing
        - factor: Factor the learning rate is reduced by
        - min_lr: Lower bound on the learning rate
        - reset_on_fit: Whether the patience counter restarts on each new fit
        """

class LRFinder(Callback):
    """Sweep a range of learning rates to help discover a good one."""

    def __init__(
        self,
        start_lr=1e-7,
        end_lr=10,
        num_it=100,
        step_mode='exp',
    ):
        """
        Parameters:
        - start_lr: Learning rate the sweep starts from
        - end_lr: Learning rate the sweep ends at
        - num_it: Number of iterations in the sweep
        - step_mode: Spacing of the learning rates between start_lr and end_lr
        """

Training Enhancement Callbacks

Callbacks that enhance training stability and performance.

class MixedPrecision(Callback):
    """
    Train with automatic mixed precision to go faster and use less memory.

    The forward pass runs in float16 while gradients are kept in float32.
    """

    def __init__(
        self,
        loss_scale=512,
        flat_master=False,
        dynamic=True,
        clip=None,
        eps=1e-5,
        scale_wait=500,
    ):
        pass

class GradientClip(Callback):
    """
    Clip gradients so that training stays stable.

    Parameters:
    - max_norm: Largest allowed gradient norm (default 1.0)
    - norm_type: Type of norm used (default 2.0)
    """

    def __init__(self, max_norm=1.0, norm_type=2.0):
        pass

class GradientAccumulation(Callback):
    """
    Accumulate gradients across several batches before stepping the optimizer.

    Parameters:
    - n_acc: Accumulation amount before each optimizer step (default 32)
    """

    def __init__(self, n_acc=32):
        pass

class BnFreeze(Callback):
    """
    Keep batch normalization layers frozen during training.

    Only the before_epoch event is hooked.
    """

    def before_epoch(self):
        pass

Monitoring and Logging

Callbacks for tracking training progress and logging to external services.

class ProgressCallback(Callback):
    """
    Show progress bars while training runs.

    Parameters:
    - plot: Whether to plot progress (default False)
    - display: Whether to display the progress bars (default True)
    """

    def __init__(self, plot=False, display=True):
        pass

    def before_fit(self):
        pass

    def after_batch(self):
        pass

    def after_epoch(self):
        pass

class CSVLogger(Callback):
    """
    Write training metrics to a CSV file after every epoch.

    Parameters:
    - fname: Name of the CSV file (default 'history.csv')
    - append: Whether to append to an existing file
    """

    def __init__(self, fname='history.csv', append=False):
        pass

    def after_epoch(self):
        pass

class TensorBoardCallback(Callback):
    """
    Send metrics and the model graph to TensorBoard.

    Parameters:
    - log_dir: Directory for the TensorBoard logs
    - trace_model: Whether to log the model graph
    - log_preds: Whether to log predictions
    - n_preds: Number of predictions to log (default 9)
    - projector: Whether to use the embedding projector
    """

    def __init__(
        self,
        log_dir=None,
        trace_model=True,
        log_preds=True,
        n_preds=9,
        projector=False,
    ):
        pass

    def before_fit(self):
        pass

    def after_epoch(self):
        pass

    def after_fit(self):
        pass

class WandbCallback(Callback):
    """
    Track experiments with Weights & Biases.

    Parameters:
    - log_preds: Whether to log predictions
    - log_model: Whether to log the model
    - log_dataset: Whether to log the dataset
    - dataset_name: Name under which the dataset is logged
    - valid_idx: Index of the validation set (default 1)
    - n_preds: Number of predictions to log (default 36)
    - seed: Random seed (default 12345)
    """

    def __init__(
        self,
        log_preds=True,
        log_model=True,
        log_dataset=False,
        dataset_name=None,
        valid_idx=1,
        n_preds=36,
        seed=12345,
    ):
        pass

    def before_fit(self):
        pass

    def after_epoch(self):
        pass

    def after_fit(self):
        pass

class CometCallback(Callback):
    """
    Track experiments with Comet.ml.

    Parameters:
    - log_model: Whether to log the model
    - log_dataset: Whether to log the dataset
    - project_name: Comet project name
    - log_code: Whether to log the source code
    - log_preds: Whether to log predictions
    - n_preds: Number of predictions to log (default 9)
    """

    def __init__(
        self,
        log_model=True,
        log_dataset=False,
        project_name=None,
        log_code=True,
        log_preds=True,
        n_preds=9,
    ):
        pass

    def before_fit(self):
        pass

    def after_epoch(self):
        pass

Model Management Callbacks

Callbacks for saving, loading, and managing model checkpoints.

class SaveModelCallback(Callback):
    """Write model checkpoints to disk while training."""

    def __init__(
        self,
        monitor='valid_loss',
        comp=None,
        min_delta=0,
        fname='bestmodel',
        every_epoch=False,
        at_end=False,
        with_opt=False,
        reset_on_fit=True,
    ):
        """
        Configure checkpoint saving.

        Parameters:
        - monitor: Metric that decides which model is best
        - comp: Comparison function (np.less for a loss)
        - min_delta: Smallest improvement that counts
        - fname: File name used for the saved model
        - every_epoch: Whether to save after every epoch
        - at_end: Whether to save when training finishes
        - with_opt: Whether the optimizer state is saved too
        - reset_on_fit: Whether the best metric is reset on each new fit
        """

class EarlyStoppingCallback(Callback):
    """Halt training once the monitored metric stops improving."""

    def __init__(
        self,
        monitor='valid_loss',
        comp=None,
        min_delta=0,
        patience=1,
        restore_best_weights=True,
        reset_on_fit=True,
    ):
        """
        Configure early stopping.

        Parameters:
        - monitor: Metric to watch
        - comp: Comparison function deciding what counts as better
        - min_delta: Smallest improvement that counts
        - patience: Number of epochs to wait
        - restore_best_weights: Whether the best weights are restored on stop
        - reset_on_fit: Whether the counter is reset on each new fit
        """

Regularization and Augmentation

Callbacks implementing regularization techniques and data augmentation.

class MixUp(Callback):
    """
    MixUp augmentation applied while training.

    Blends pairs of examples together with their labels.

    Parameters:
    - alpha: Mixing parameter (default 0.4)
    - stack_x: Whether inputs are stacked
    - stack_y: Whether labels are stacked
    """

    def __init__(self, alpha=0.4, stack_x=False, stack_y=True):
        pass

    def before_batch(self):
        pass

class CutMix(Callback):
    """
    CutMix augmentation: spatial mixing in the spirit of MixUp.

    Patches are cut from one training image and pasted into another.

    Parameters:
    - alpha: Mixing parameter (default 1.0)
    """

    def __init__(self, alpha=1.0):
        pass

    def before_batch(self):
        pass

class RNNRegularizer(Callback):
    """
    Regularization tailored to RNN models.

    Parameters:
    - alpha: First regularization coefficient (default 2)
    - beta: Second regularization coefficient (default 1)
    - **kwargs: Additional keyword arguments
    """

    def __init__(self, alpha=2, beta=1, **kwargs):
        pass

class ChannelsLast(Callback):
    """Use a channels-last memory layout to speed up CNNs."""

    def before_fit(self):
        pass

    def before_batch(self):
        pass

Advanced Training Techniques

Callbacks implementing advanced training strategies and techniques.

class LabelSmoothingCrossEntropy(Callback):
    """
    Regularize by smoothing the target labels.

    Parameters:
    - eps: Smoothing amount (default 0.1)
    - reduction: Reduction mode (default 'mean')

    NOTE(review): the ``reduction`` argument reads like a loss-function
    parameter; confirm whether this is meant to be a Callback.
    """

    def __init__(self, eps=0.1, reduction='mean'):
        pass

class SelfDistillation(Callback):
    """
    Train with self-distillation.

    Parameters:
    - temperature: Distillation temperature (default 3.0)
    - alpha: Weighting coefficient (default 0.7)
    """

    def __init__(self, temperature=3.0, alpha=0.7):
        pass

class Lookahead(Callback):
    """
    Wrap the optimizer with Lookahead.

    Parameters:
    - k: Lookahead step count (default 5)
    - alpha: Interpolation coefficient (default 0.5)
    """

    def __init__(self, k=5, alpha=0.5):
        pass

class FreezeCallback(Callback):
    """
    Freeze and unfreeze model layers over the course of training.

    Parameters:
    - freeze_epochs: Number of epochs to keep layers frozen (default 1)
    """

    def __init__(self, freeze_epochs=1):
        pass

    def before_epoch(self):
        pass

class ShowGraphCallback(Callback):
    """Display the model architecture and training graphs once fitting ends."""

    def after_fit(self):
        pass

Custom Callback Utilities

Utilities for creating and managing custom callbacks.

def callback_handler(cbs=None, **kwargs):
    """
    Build a callback handler around a list of callbacks.

    Parameters:
    - cbs: Callbacks to manage
    - **kwargs: Additional keyword arguments
    """

class CallbackHandler:
    """Own a collection of callbacks and dispatch events to them."""

    def __init__(self, cbs=None):
        pass

    def add_cb(self, cb):
        """Register ``cb`` with this handler."""

    def remove_cb(self, cb):
        """Unregister ``cb`` from this handler."""

    def __call__(self, event_name):
        """Dispatch the event named ``event_name``."""

class CancelFitException(Exception):
    """Raise to cancel training."""


class CancelEpochException(Exception):
    """Raise to cancel the current epoch."""


class CancelTrainException(Exception):
    """Raise to cancel the training phase."""


class CancelValidException(Exception):
    """Raise to cancel the validation phase."""


class CancelBatchException(Exception):
    """Raise to cancel the current batch."""

Training Control and Debugging Callbacks

Advanced callbacks for training control, debugging, and model analysis.

class TerminateOnNaNCallback(Callback):
    """
    Stop training automatically when the loss turns NaN or infinite.

    Important for robust training pipelines.
    """

    order = -9

    def after_batch(self):
        """Check the loss for NaN/inf and interrupt training if found."""

class ShortEpochCallback(Callback):
    """
    Train on only a fraction of each epoch, for debugging and testing.

    Parameters:
    - pct: Fraction of the epoch to train on (0.01 = 1%)
    - short_valid: Whether validation is shortened as well
    """

    def __init__(self, pct=0.01, short_valid=True):
        pass

class CollectDataCallback(Callback):
    """
    Gather every batch together with its predictions and losses.

    Handy for debugging and for analysing how the model behaves.
    """

    def before_fit(self):
        pass

    def after_batch(self):
        pass

Model Analysis and Hook Callbacks

Callbacks for analyzing model internals and registering hooks on model layers.

# HookCallback is defined first on purpose: a base-class expression is
# evaluated when the subclass's `class` statement runs, so declaring
# ActivationStats(HookCallback) before HookCallback exists raises NameError.
class HookCallback(Callback):
    """
    Base callback for registering hooks on model modules.
    Foundation for advanced model introspection and analysis.

    Parameters:
    - modules: Specific modules to hook (None = all with params)
    - every: Register hooks every N training iterations
    - remove_end: Remove hooks after training
    - is_forward: Forward vs backward hooks
    - detach: Detach tensors from computation graph
    - cpu: Move hooked data to CPU
    - include_paramless: Include modules without parameters
    """
    def __init__(self, modules=None, every=None, remove_end=True,
                 is_forward=True, detach=True, cpu=True,
                 include_paramless=False): ...

class ActivationStats(HookCallback):
    """
    Record activation statistics (mean, std, near-zero percentage) during training.
    Essential for debugging vanishing/exploding gradients and dead neurons.

    Parameters:
    - with_hist: Whether to record activation histograms
    - **kwargs: Additional keyword arguments (presumably forwarded to
      HookCallback; confirm)
    """
    order = -20

    def __init__(self, with_hist=False, **kwargs): ...

    def layer_stats(self, idx):
        """Recorded statistics for the hooked layer at index ``idx``."""

    def hist(self, idx):
        """
        Activation histogram for the hooked layer at index ``idx``.

        NOTE(review): presumably requires ``with_hist=True``; confirm.
        """

    def color_dim(self, idx, figsize=(10,5)):
        """
        Plot a colour-coded view of layer ``idx``'s activations.

        Parameters:
        - idx: Index of the hooked layer
        - figsize: Figure size (default (10, 5))
        """

    def plot_layer_stats(self, idx):
        """Plot the recorded statistics for the hooked layer at index ``idx``."""

RNN-Specific Callbacks

Specialized callbacks for training recurrent neural networks and sequence models.

class ModelResetter(Callback):
    """
    Reset RNN hidden states when switching between training and validation.

    Required for correct training of RNNs that carry hidden state.
    """

    def before_train(self):
        pass

    def before_validate(self):
        pass

    def after_fit(self):
        pass

class RNNCallback(Callback):
    """
    Handle RNN outputs, keeping the raw and dropout outputs for regularization.

    Takes care of the extra bookkeeping that RNN training loops need.
    """

    def after_pred(self):
        pass

Advanced Prediction and Uncertainty Callbacks

Callbacks for enhanced prediction gathering and uncertainty estimation.

class MCDropoutCallback(Callback):
    """
    Monte Carlo Dropout for uncertainty estimation.

    Dropout layers stay active during validation so predictions are
    probabilistic.
    """

    def before_validate(self):
        pass

    def after_validate(self):
        pass

class FetchPredsCallback(Callback):
    """
    Fetch predictions during training loop with callback management.

    Parameters:
    - ds_idx: Dataset index (0=train, 1=valid)
    - dl: Custom DataLoader for predictions
    - with_input: Whether to also return the model inputs
    - with_decoded: Whether to also return decoded predictions
    - cbs: Callbacks to temporarily remove
    - reorder: Sort prediction results
    """
    def __init__(self, ds_idx=1, dl=None, with_input=False,
                 with_decoded=False, cbs=None, reorder=True): ...

Advanced Mixed Precision Training

Enhanced mixed precision training with fine-grained control over scaling and gradients.

class NonNativeMixedPrecision(Callback):
    """
    Hand-rolled mixed precision training for fine-grained control.

    Offers more flexibility than PyTorch's native automatic mixed precision.

    Parameters:
    - loss_scale: Factor the loss is scaled by, for gradient stability
    - flat_master: Whether the fp32 master parameters are flattened
    - dynamic: Whether the loss scale adjusts itself automatically
    - max_loss_scale: Upper bound on the loss scale
    - div_factor: Factor used when adjusting the scale
    - scale_wait: Batches to wait before increasing the scale
    - clip: Gradient clipping value
    """

    order = 10

    def __init__(
        self,
        loss_scale=512,
        flat_master=False,
        dynamic=True,
        max_loss_scale=2.**24,
        div_factor=2.,
        scale_wait=500,
        clip=None,
    ):
        pass

Integration and Production Callbacks

Callbacks for integration with external platforms and production workflows.

class AzureMLCallback(Callback):
    """
    Track experiments in Azure Machine Learning.

    Metrics, parameters and models are logged to Azure ML automatically.

    Parameters:
    - learn: The Learner instance
    - log_model: Whether the trained model is logged
    - model_name: Name under which the model is logged
    """

    def __init__(self, learn=None, log_model=False, model_name='model'):
        pass

class CaptumInterpretation:
    """
    Model interpretability using Facebook's Captum library.
    Provides advanced attribution and visualization methods.

    Parameters:
    - learn: Learner instance
    - cmap_name: Colormap name for visualizations
    - colors: Optional colors for the colormap (presumably combined with
      cmap_name and N to build it; confirm)
    - N: Number of colormap entries (default 256)
    - methods: Visualization methods
    - signs: Attribution signs to display
    """
    def __init__(self, learn, cmap_name='custom blue', colors=None, N=256,
                 methods=('original_image', 'heat_map'), signs=("all", "positive")): ...

    def visualize(self, inp, metric='IG', n_steps=1000, baseline_type='zeros'):
        """
        Visualize attributions for ``inp``.

        Parameters:
        - inp: Input to explain
        - metric: Attribution method name (default 'IG'; presumably
          Integrated Gradients)
        - n_steps: Number of steps for the attribution method (default 1000)
        - baseline_type: Attribution baseline (default 'zeros')
        """

    def insights(self, inp_data, debug=True):
        """
        Explore ``inp_data`` with Captum Insights.

        Parameters:
        - inp_data: Input data to explore
        - debug: Debug flag (default True)
        """

Install with Tessl CLI

npx tessl i tessl/pypi-fastai

docs

callbacks.md

collaborative-filtering.md

core-training.md

data-loading.md

index.md

interpretation.md

medical.md

metrics-losses.md

tabular.md

text.md

vision.md

tile.json