CtrlK
Blog | Docs | Log in | Get started
Tessl Logo

tessl/pypi-fastai

fastai simplifies training fast and accurate neural nets using modern best practices

Pending
Overview
Eval results
Files

docs/interpretation.md

Model Interpretation

Tools for understanding and interpreting model predictions including visualization utilities, analysis methods, and techniques for gaining insights into model behavior and decision-making processes.

Capabilities

Classification Interpretation

Comprehensive analysis tools for understanding classification model predictions and performance.

class ClassificationInterpretation:
    """
    Interpretation tools for classification models.

    Provides methods to analyze predictions, visualize confusion matrices,
    and identify model strengths and weaknesses. Construct instances via
    the ``from_learner`` classmethod rather than calling the class directly.
    """
    
    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        """
        Create an interpretation object from a trained learner.
        
        Parameters:
        - learn: Trained Learner instance
        - ds_idx: Dataset index to evaluate (1 for validation)
        - dl: Custom DataLoader (uses the learner's DataLoader at ds_idx if None)
        - act: Activation function applied to raw predictions (model default if None)
        
        Returns:
        - ClassificationInterpretation instance
        """
    
    def confusion_matrix(self, slice_size=1):
        """
        Compute the confusion matrix for the stored predictions.
        
        Parameters:
        - slice_size: Number of examples processed per slice, to bound memory use
        
        Returns:
        - Confusion matrix as a tensor (rows: actual, columns: predicted)
        """
    
    def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', 
                             cmap="Blues", figsize=None, **kwargs):
        """
        Plot the confusion matrix as a heatmap.
        
        Parameters:
        - normalize: If True, normalize counts before plotting
        - title: Plot title
        - cmap: Matplotlib colormap name for the heatmap
        - figsize: Figure size
        - **kwargs: Additional plotting arguments
        """
    
    def most_confused(self, min_val=1):
        """
        Find the most confused class pairs.
        
        Parameters:
        - min_val: Minimum confusion count for a pair to be included
        
        Returns:
        - List of (actual, predicted, count) tuples sorted by confusion count
        """
    
    def plot_top_losses(self, k, largest=True, figsize=(12,12), **kwargs):
        """
        Plot the examples with the most extreme losses.
        
        Parameters:
        - k: Number of examples to show
        - largest: If True show the largest losses, otherwise the smallest
        - figsize: Figure size
        - **kwargs: Additional plotting arguments
        """
    
    def top_losses(self, k=None, largest=True):
        """
        Get the examples with the most extreme losses.
        
        Parameters:
        - k: Number of examples to return (all if None)
        - largest: If True return the largest losses, otherwise the smallest
        
        Returns:
        - Tuple of (losses, indices)
        """
    
    def print_classification_report(self):
        """Print a detailed classification report with precision, recall, and F1."""

Segmentation Interpretation

Specialized interpretation tools for segmentation models and pixel-level predictions.

class SegmentationInterpretation:
    """
    Interpretation tools for segmentation models.

    Mirrors the classification interpretation API but operates on
    pixel-level predictions. Construct via ``from_learner``.
    """
    
    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        """
        Create a segmentation interpretation from a trained learner.

        Parameters:
        - learn: Trained Learner instance
        - ds_idx: Dataset index to evaluate (1 for validation)
        - dl: Custom DataLoader (uses the learner's if None)
        - act: Activation function applied to raw predictions

        Returns:
        - SegmentationInterpretation instance
        """
    
    def plot_top_losses(self, k, largest=True, figsize=(12,12), **kwargs):
        """
        Plot the segmentation examples with the most extreme losses.

        Parameters:
        - k: Number of examples to show
        - largest: If True show the largest losses, otherwise the smallest
        - figsize: Figure size
        - **kwargs: Additional plotting arguments
        """
    
    def confusion_matrix(self, slice_size=1):
        """
        Compute the pixel-wise confusion matrix.

        Parameters:
        - slice_size: Number of examples processed per slice, to bound memory use

        Returns:
        - Confusion matrix over pixel classifications
        """
    
    def plot_confusion_matrix(self, normalize=False, **kwargs):
        """
        Plot the pixel-wise confusion matrix.

        Parameters:
        - normalize: If True, normalize counts before plotting
        - **kwargs: Additional plotting arguments
        """
    
    def per_class_accuracy(self):
        """Calculate the pixel accuracy for each segmentation class."""
    
    def intersection_over_union(self):
        """Calculate the intersection-over-union (IoU) score for each class."""

Base Interpretation Classes

Foundation classes for building custom interpretation tools.

class Interpretation:
    """
    Base class for model interpretation.

    Stores a DataLoader together with the inputs, predictions, targets,
    decoded predictions, and per-example losses produced on it; subclasses
    add task-specific analysis on top of this data.
    """
    
    def __init__(self, dl, inputs, preds, targs, decoded, losses):
        """
        Initialize the interpretation with precomputed prediction data.
        
        Parameters:
        - dl: DataLoader used to generate the predictions
        - inputs: Model inputs
        - preds: Raw predictions
        - targs: Target values
        - decoded: Decoded (post-processed) predictions
        - losses: Loss value for each example
        """
    
    def top_losses(self, k=None, largest=True):
        """
        Get the examples with the highest (or lowest) losses.

        Parameters:
        - k: Number of examples to return (all if None)
        - largest: If True return the largest losses, otherwise the smallest
        """
    
    def plot_top_losses(self, k, largest=True, **kwargs):
        """
        Plot the examples with the most extreme losses.

        Parameters:
        - k: Number of examples to show
        - largest: If True show the largest losses, otherwise the smallest
        - **kwargs: Additional plotting arguments
        """

def plot_top_losses(interp, k, largest=True, **kwargs):
    """
    Utility function to plot the examples with the most extreme losses.

    Parameters:
    - interp: An Interpretation instance holding predictions and losses
    - k: Number of examples to show
    - largest: If True show the largest losses, otherwise the smallest
    - **kwargs: Additional plotting arguments
    """

Gradient-Based Interpretation

Methods using gradients to understand model decisions and feature importance.

class GradCAM:
    """
    Gradient-weighted Class Activation Mapping.

    Visualizes which regions of an input image are most important for a
    model's prediction by weighting a convolutional layer's activations
    with the gradients of the target class score.
    """
    
    def __init__(self, learn, layer=None):
        """
        Initialize GradCAM.
        
        Parameters:
        - learn: Trained learner
        - layer: Target layer for activation maps (last conv layer if None)
        """
    
    def __call__(self, x, class_idx=None):
        """
        Generate a GradCAM heatmap for one input.
        
        Parameters:
        - x: Input image
        - class_idx: Target class index (the predicted class if None)
        
        Returns:
        - Heatmap highlighting the regions most important for the prediction
        """

class IntegratedGradients:
    """
    Integrated Gradients for feature attribution.

    Computes gradients along a straight-line path from a baseline input to
    the actual input and integrates them to attribute the prediction to
    input features.
    """
    
    def __init__(self, learn, baseline=None):
        """
        Initialize Integrated Gradients.
        
        Parameters:
        - learn: Trained learner
        - baseline: Baseline input to integrate from (all zeros if None)
        """
    
    def attribute(self, x, target=None, n_steps=50):
        """
        Compute the integrated gradients attribution for one input.
        
        Parameters:
        - x: Input to analyze
        - target: Target class (the predicted class if None)
        - n_steps: Number of steps used to approximate the path integral
        
        Returns:
        - Attribution map with one score per input feature
        """

def gradient_times_input(learn, x, target=None):
    """
    Simple gradient * input attribution method.

    Parameters:
    - learn: Trained learner
    - x: Input to analyze
    - target: Target class (presumably the predicted class if None — matches
      the convention used by the other attribution methods in this module)

    Returns:
    - Attribution map (element-wise product of gradient and input)
    """

def saliency_map(learn, x, target=None):
    """
    Generate a saliency map from the gradients of the target class score.

    Parameters:
    - learn: Trained learner
    - x: Input to analyze
    - target: Target class (presumably the predicted class if None — matches
      the convention used by the other attribution methods in this module)

    Returns:
    - Saliency map highlighting input regions the prediction is sensitive to
    """

Feature Importance Analysis

Tools for analyzing feature importance in different types of models.

class FeatureImportance:
    """
    Analyze feature importance for tabular models.

    Computes permutation-based importance scores and provides a plotting
    helper to visualize them.
    """
    
    def __init__(self, learn):
        """
        Initialize with a trained tabular learner.

        Parameters:
        - learn: Trained tabular Learner instance
        """
    
    def permutation_importance(self, dl=None, n_repeats=5, random_state=None):
        """
        Calculate permutation-based feature importance.
        
        Parameters:
        - dl: DataLoader (uses the validation set if None)
        - n_repeats: Number of times each feature is permuted
        - random_state: Random seed for reproducible permutations
        
        Returns:
        - Feature importance scores
        """
    
    def plot_importance(self, max_vars=20, figsize=(8,6)):
        """
        Plot the feature importance scores.

        Parameters:
        - max_vars: Maximum number of features to show
        - figsize: Figure size
        """

def rfpimp_importance(learn, dl=None):
    """
    Random forest-style permutation importance.

    Parameters:
    - learn: Trained learner
    - dl: DataLoader (presumably the validation set if None — matches the
      convention of permutation_importance above)

    Returns:
    - Feature importance scores
    """

def oob_score_importance(learn, dl=None):
    """
    Out-of-bag score-based feature importance.

    Parameters:
    - learn: Trained learner
    - dl: DataLoader (presumably the validation set if None — matches the
      convention of permutation_importance above)

    Returns:
    - Feature importance scores
    """

Prediction Analysis

Tools for analyzing and visualizing model predictions across different domains.

def plot_predictions(learn, ds_idx=1, max_n=9, figsize=None, **kwargs):
    """
    Plot model predictions alongside the ground truth.
    
    Parameters:
    - learn: Trained learner
    - ds_idx: Dataset index (1 for validation)
    - max_n: Maximum number of examples to show
    - figsize: Figure size
    - **kwargs: Additional plotting arguments
    """

def show_results(learn, ds_idx=1, dl=None, max_n=10, shuffle=True, **kwargs):
    """
    Show model results on a dataset.

    Parameters:
    - learn: Trained learner
    - ds_idx: Dataset index (1 for validation)
    - dl: Custom DataLoader (uses the learner's if None)
    - max_n: Maximum number of examples to show
    - shuffle: If True, sample examples in random order
    - **kwargs: Additional plotting arguments
    """

class PredictionAnalyzer:
    """
    Analyze prediction patterns and model behavior.

    Groups distribution, confidence, and error analyses over a learner's
    predictions on a DataLoader.
    """
    
    def __init__(self, learn, dl=None):
        """
        Initialize the analyzer.

        Parameters:
        - learn: Trained learner
        - dl: DataLoader to analyze (uses the learner's data if None)
        """
    
    def prediction_distribution(self):
        """Analyze the distribution of prediction scores."""
    
    def confidence_analysis(self):
        """Analyze patterns in prediction confidence."""
    
    def error_analysis(self):
        """Analyze patterns in the model's errors."""

Visualization Utilities

Utility functions for creating informative visualizations of model behavior.

def plot_multi_losses(losses_list, labels=None, figsize=(12,8)):
    """
    Plot multiple loss curves on one figure for comparison.

    Parameters:
    - losses_list: Sequence of loss curves, one per run/model
    - labels: Legend labels, one per curve (unlabeled if None)
    - figsize: Figure size
    """

def plot_lr_find(learn, skip_start=5, skip_end=5, suggestion=True):
    """
    Plot learning rate finder results.

    Parameters:
    - learn: Learner on which the LR finder was run
    - skip_start: Number of recorded points to drop at the start of the curve
    - skip_end: Number of recorded points to drop at the end of the curve
    - suggestion: If True, mark a suggested learning rate on the plot
    """

def plot_metrics(learn, nrows=None, ncols=None, figsize=None):
    """
    Plot all tracked metrics.

    Parameters:
    - learn: Trained learner with recorded metrics
    - nrows: Number of subplot rows (chosen automatically if None)
    - ncols: Number of subplot columns (chosen automatically if None)
    - figsize: Figure size
    """

def show_batch_predictions(learn, dl=None, max_n=9, figsize=None, **kwargs):
    """
    Show a batch with the model's predictions overlaid.

    Parameters:
    - learn: Trained learner
    - dl: DataLoader to draw the batch from (uses the learner's if None)
    - max_n: Maximum number of examples to show
    - figsize: Figure size
    - **kwargs: Additional plotting arguments
    """

class ActivationStats:
    """
    Analyze activation statistics across model layers.

    Collects per-layer activation statistics from a learner's model and
    provides a plotting helper to visualize them.
    """
    
    def __init__(self, learn):
        """
        Initialize with a trained learner.

        Parameters:
        - learn: Learner whose model activations will be analyzed
        """
    
    def stats_by_layer(self):
        """Get activation statistics for each layer."""
    
    def plot_layer_stats(self, figsize=(15,5)):
        """
        Plot the per-layer activation statistics.

        Parameters:
        - figsize: Figure size
        """

def dead_chart(activs, figsize=(10,5)):
    """
    Chart showing dead (never-activating) neurons by layer.

    Parameters:
    - activs: Per-layer activation data (e.g. from ActivationStats)
    - figsize: Figure size
    """

def hist_chart(activs, figsize=(10,5)):
    """
    Histogram of activations by layer.

    Parameters:
    - activs: Per-layer activation data (e.g. from ActivationStats)
    - figsize: Figure size
    """

Model Debugging

Tools for debugging model architecture and training issues.

class ModelDebugger:
    """
    Debug model architecture and training issues.

    Bundles checks for gradient flow, dead neurons, weight distributions,
    and per-layer outputs on a trained learner's model.
    """
    
    def __init__(self, learn):
        """
        Initialize the debugger.

        Parameters:
        - learn: Learner whose model will be inspected
        """
    
    def check_gradient_flow(self):
        """Check for gradient flow issues (e.g. vanishing or exploding gradients)."""
    
    def analyze_layer_outputs(self, x):
        """
        Analyze the outputs produced by each layer for one input.

        Parameters:
        - x: Input to feed through the model
        """
    
    def detect_dead_neurons(self):
        """Detect neurons that never activate."""
    
    def weight_distribution_analysis(self):
        """Analyze the weight distributions across layers."""

def summary(learn, input_size=None):
    """
    Print a model summary with layer details.

    Parameters:
    - learn: Learner whose model is summarized
    - input_size: Input size used to trace the model (inferred if None)
    """

def model_sizes(learn):
    """
    Analyze model memory usage by layer.

    Parameters:
    - learn: Learner whose model is analyzed
    """

def check_model(learn, lr=1e-3):  
    """
    Run model health checks.

    Parameters:
    - learn: Learner to check
    - lr: Learning rate used for the checks (presumably for a trial
      training step — confirm against the implementation)
    """

Interactive Interpretation

Tools for interactive exploration of model predictions and behavior.

class InteractiveClassifier:
    """Interactive widget for exploring classification predictions."""
    
    def __init__(self, learn, ds_idx=1):
        """
        Initialize the interactive classifier.

        Parameters:
        - learn: Trained learner
        - ds_idx: Dataset index to explore (1 for validation)
        """
    
    def show(self):
        """Display the interactive widget."""

class InteractiveSegmentation:
    """Interactive widget for exploring segmentation predictions."""
    
    def __init__(self, learn, ds_idx=1):
        """
        Initialize the interactive segmentation explorer.

        Parameters:
        - learn: Trained learner
        - ds_idx: Dataset index to explore (1 for validation)
        """
    
    def show(self):
        """Display the interactive widget."""

def create_interpretation_widget(learn, interpretation_type='classification'):
    """
    Create the appropriate interpretation widget for a model type.

    Parameters:
    - learn: Trained learner
    - interpretation_type: Kind of widget to create ('classification' by default)
    """

Install with Tessl CLI

npx tessl i tessl/pypi-fastai

docs

callbacks.md

collaborative-filtering.md

core-training.md

data-loading.md

index.md

interpretation.md

medical.md

metrics-losses.md

tabular.md

text.md

vision.md

tile.json