CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-timm

PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks

Overview
Eval results
Files

docs/features.md

Feature Extraction

Advanced functionality for extracting features from models, analyzing model architecture, and manipulating pretrained models for custom use cases.

Capabilities

Feature Extractor Creation

Create feature extractors that return intermediate representations from any node of a model's traced (torch.fx) graph, not only from its top-level layers.

def create_feature_extractor(
    model: torch.nn.Module,
    return_nodes: Union[Dict[str, str], List[str]],
    train_return_nodes: Optional[Union[Dict[str, str], List[str]]] = None,
    suppress_diff_warnings: bool = False,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs
):
    """
    Create a feature extractor from a model using FX graph tracing.

    The model is symbolically traced and rewritten so that its forward pass
    returns the outputs of the requested graph nodes instead of its normal
    output.

    Args:
        model: Source model to extract features from.
        return_nodes: Graph nodes whose outputs are returned. Either a dict
            mapping node names to the keys used in the returned dict, or a
            list of node names (the names themselves are then the keys).
            Candidate names can be listed with ``get_graph_node_names``.
        train_return_nodes: Node specification to use in training mode when it
            differs from eval mode; if omitted, ``return_nodes`` is used.
        suppress_diff_warnings: Suppress the warning emitted when the graphs
            traced in training mode and eval mode differ.
        tracer_kwargs: Extra keyword arguments for the FX tracer (e.g. leaf
            modules or functions that must not be traced through).
        **kwargs: Additional arguments.

    Returns:
        Feature extractor module whose forward returns a dict of the requested
        intermediate features, keyed as described for ``return_nodes``.

    NOTE(review): the extra parameters mirror torchvision's
    ``create_feature_extractor``; upstream timm's wrapper is believed to accept
    only ``model`` and ``return_nodes`` and to fill ``tracer_kwargs`` with
    timm's no-trace modules/functions itself — confirm against the installed
    version.
    """

def get_graph_node_names(
    model: torch.nn.Module,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warnings: bool = False
) -> Tuple[List[str], List[str]]:
    """
    List the FX graph node names usable as feature-extraction points.

    The model is traced once in training mode and once in eval mode, because
    the two graphs can differ (e.g. dropout or stochastic-depth branches).

    Args:
        model: Model to analyze.
        tracer_kwargs: Additional arguments for the FX tracer.
        suppress_diff_warnings: Suppress the warning emitted when the
            training-mode and eval-mode graphs differ.

    Returns:
        Tuple ``(train_nodes, eval_nodes)``: node names from the training-mode
        trace and from the eval-mode trace, in execution order. Both lists
        contain node *names* (not node types); any entry can be used as a
        ``return_nodes`` key for ``create_feature_extractor``.
    """

Feature Extraction Classes

Wrapper- and Hook-Based Feature Extraction

class FeatureInfo:
    """
    Information about the features a feature-extraction model outputs.

    Args:
        feature_info: List of per-stage feature information dicts (entries
            such as ``num_chs``, ``reduction`` and ``module``).
        out_indices: Indices of the stages whose features are output.
    """

    def __init__(
        self,
        feature_info: List[Dict[str, Any]],
        out_indices: List[int]
    ): ...

    def get_dicts(self, keys: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Get feature info as a list of dicts, optionally limited to ``keys``."""

    def channels(self, idx: Optional[int] = None) -> Union[List[int], int]:
        """
        Get feature channel counts.

        Returns a single count for feature ``idx``, or (presumably, per the
        ``List[int]`` return type) a list over all output features when
        ``idx`` is None.
        """

    def reduction(self, idx: Optional[int] = None) -> Union[List[int], int]:
        """
        Get feature reduction factors (downsampling relative to the input).

        Returns a single factor for feature ``idx``, or (presumably) a list
        over all output features when ``idx`` is None.
        """

class FeatureHooks:
    """
    Feature extraction using forward hooks.

    Args:
        hooks: Hooks to register, one per feature to capture.
        named_modules: Mapping from module name to module, used to locate the
            modules the hooks are attached to.
        out_map: Optional output keys for the captured features, one per hook.

    NOTE(review): upstream timm passes hook *specifications* (module names /
    dicts with a hook type) rather than plain callables — confirm.
    """

    def __init__(
        self,
        hooks: List[Callable],
        named_modules: Dict[str, torch.nn.Module],
        out_map: Optional[List[int]] = None
    ): ...

    def get_output(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Get the features captured by the hooks during a forward pass.

        NOTE(review): upstream timm's ``get_output`` is believed to take a
        ``device`` (captured outputs are stored per device) and does not run
        the model itself, so the ``x`` parameter here looks inaccurate —
        confirm against the installed version.
        """

class FeatureHookNet(torch.nn.Module):
    """
    Wrapper that uses hooks to extract features during the forward pass.

    Args:
        model: Base model to wrap.
        out_indices: Indices of the feature stages to extract.
        out_map: Optional names for the output features, one per index.
        return_interm: Return intermediate features.
        **kwargs: Additional arguments.

    NOTE(review): upstream timm's ``FeatureHookNet`` is believed to expose
    ``return_dict`` (dict vs. list output) rather than ``return_interm`` —
    confirm against the installed version.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        return_interm: bool = False,
        **kwargs
    ): ...

class FeatureListNet(torch.nn.Module):
    """
    Feature-extraction wrapper whose forward returns a list of feature maps.

    Args:
        model: The network to wrap.
        out_indices: Indices of the feature stages to emit.
        **kwargs: Extra options accepted by the wrapper.
    """

    def __init__(self, model: torch.nn.Module, out_indices: List[int], **kwargs): ...

class FeatureDictNet(torch.nn.Module):
    """
    Wrapper that returns features as a dictionary.

    Args:
        model: Base model to wrap.
        out_indices: Indices of the feature stages to extract.
        out_map: Optional names (dict keys) for the output features, one per
            entry of ``out_indices``; when None, default names are presumably
            used — confirm.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        **kwargs
    ): ...

FX-Based Feature Extraction

class FeatureGraphNet(torch.nn.Module):
    """
    FX-based feature extraction network.

    Selects features by stage index (like the hook/list wrappers) but obtains
    them via FX graph tracing.

    Args:
        model: Base model.
        out_indices: Indices of the feature stages to output.
        out_map: Optional names for the output features, one per index.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        **kwargs
    ): ...

class GraphExtractNet(torch.nn.Module):
    """
    FX-traced wrapper returning the outputs of chosen graph nodes.

    Args:
        model: The network to trace.
        return_nodes: Mapping from graph node name to the name its output is
            returned under.
        **kwargs: Further wrapper options.
    """

    def __init__(self, model: torch.nn.Module, return_nodes: Dict[str, str], **kwargs): ...

Model Manipulation

Model Analysis and Modification

def model_parameters(
        model: torch.nn.Module, exclude_head: bool = False, recurse: bool = True,
) -> Iterator[torch.nn.Parameter]:
    """
    Iterate over a model's parameters, optionally leaving out the head.

    Args:
        model: Network whose parameters are wanted.
        exclude_head: When True, skip the classifier/head parameters.
        recurse: When True, include parameters of nested submodules.

    Returns:
        Iterator of ``torch.nn.Parameter`` objects.

    NOTE(review): upstream timm's ``model_parameters`` is believed to have no
    ``recurse`` argument and to implement ``exclude_head`` by dropping the
    last two entries of ``model.parameters()`` (assumed head weight/bias),
    returning a list in that case — confirm against the installed version.
    """

def named_apply(
        fn: Callable, module: torch.nn.Module, name: str = '',
        depth_first: bool = True, include_root: bool = False,
) -> torch.nn.Module:
    """
    Walk a module tree and call ``fn`` on every named submodule.

    Args:
        fn: Callable invoked for each visited module (upstream timm is
            believed to call it as ``fn(module=..., name=...)`` — confirm).
        module: Root of the tree to walk.
        name: Qualified name of ``module``, used as the prefix for children.
        depth_first: When True, children are visited before their parent.
        include_root: When True, ``fn`` is also applied to ``module`` itself.

    Returns:
        The (possibly modified) root module.
    """

def named_modules(
    module: torch.nn.Module,
    memo: Optional[set] = None,
    prefix: str = '',
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.Module]]:
    """
    Iterate over ``(name, module)`` pairs of a module and its submodules.

    Args:
        module: Root module.
        memo: Set of modules already yielded, used to skip duplicates;
            presumably a fresh empty set is used when None.
        prefix: String prepended to every yielded name.
        remove_duplicate: If True, a module reachable through several paths
            is yielded only once.

    Returns:
        Iterator of (name, module) pairs.

    NOTE(review): these parameters mirror ``torch.nn.Module.named_modules``;
    upstream timm's helper is believed to be
    ``named_modules(module, name='', depth_first=True, include_root=False)``
    — confirm against the installed version.
    """

def group_modules(
        module: torch.nn.Module, group_matcher: Callable,
        output_values: bool = False, reverse: bool = False,
) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Module]]]:
    """
    Partition the submodules of ``module`` into numbered groups.

    Args:
        module: Module whose submodules are grouped.
        group_matcher: Callable that decides the group of each submodule.
        output_values: When True the groups hold module objects; otherwise
            they hold module names.
        reverse: When True, the group ordering is reversed.

    Returns:
        Dict from group id to the list of module names (or module objects).

    NOTE(review): upstream timm is believed to name this flag
    ``return_values`` and to also accept a dict of regex rules as
    ``group_matcher`` — confirm.
    """

def group_parameters(
        module: torch.nn.Module, group_matcher: Callable,
        output_values: bool = False, reverse: bool = False,
) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Parameter]]]:
    """
    Partition the parameters of ``module`` into numbered groups.

    Typical use: building per-group settings such as layer-wise learning-rate
    decay.

    Args:
        module: Module whose parameters are grouped.
        group_matcher: Callable that decides the group of each parameter.
        output_values: When True the groups hold parameter objects; otherwise
            they hold parameter names.
        reverse: When True, the group ordering is reversed.

    Returns:
        Dict from group id to the list of parameter names (or parameters).

    NOTE(review): upstream timm is believed to name this flag
    ``return_values`` — confirm.
    """

def checkpoint_seq(
    functions: List[Callable],
    segments: int = 1,
    input: Optional[torch.Tensor] = None,
    **kwargs
) -> torch.Tensor:
    """
    Apply gradient checkpointing to a sequence of functions.

    Gradient checkpointing trades compute for memory: intermediate activations
    inside a segment are recomputed in the backward pass instead of stored.

    Args:
        functions: Functions/modules applied in order, each consuming the
            previous one's output.
        segments: Number of checkpoint segments.
        input: Input tensor for the first function. (The name shadows the
            ``input`` builtin; it is kept for interface compatibility.)
        **kwargs: Additional arguments.

    Returns:
        Output tensor of the last function.

    NOTE(review): this signature mirrors
    ``torch.utils.checkpoint.checkpoint_sequential``; upstream timm's helper is
    believed to be ``checkpoint_seq(functions, x, every=1, flatten=False,
    skip_last=False, preserve_rng_state=True)`` — confirm.
    """

Model Adaptation

def adapt_input_conv(
    model: torch.nn.Module,
    in_chans: int,
    conv_layer: Optional[str] = None
) -> torch.nn.Module:
    """
    Adapt a model's input convolution for a different number of input channels.

    Args:
        model: Model to adapt.
        in_chans: New number of input channels (e.g. 1 for grayscale).
        conv_layer: Name of the convolution layer to adapt; None presumably
            selects the model's default stem conv — TODO confirm.

    Returns:
        Model with adapted input convolution.

    NOTE(review): upstream timm's ``adapt_input_conv`` is believed to operate
    on a weight tensor, not a model: ``adapt_input_conv(in_chans, conv_weight)``
    returns the adapted weight. For a whole model the usual route is
    ``timm.create_model(name, pretrained=True, in_chans=N)``, which adapts the
    pretrained stem weights — confirm against the installed version.
    """

def load_pretrained(
    model: torch.nn.Module,
    cfg: Optional[Dict[str, Any]] = None,
    num_classes: int = 1000,
    in_chans: int = 3,
    filter_fn: Optional[Callable] = None,
    strict: bool = True,
    progress: bool = False
) -> None:
    """
    Load pretrained weights into a model.

    Args:
        model: Model to load weights into (modified in place).
        cfg: Pretrained configuration (weight location, etc.).
        num_classes: Number of output classes of ``model``.
        in_chans: Number of input channels of ``model``.
        filter_fn: Function used to filter/transform the loaded state dict.
        strict: Strict loading mode (all keys must match).
        progress: Show download progress.

    NOTE(review): upstream timm is believed to (a) call ``filter_fn`` as
    ``filter_fn(state_dict, model)`` returning a state dict, (b) adapt the
    stem conv weights when ``in_chans != 3``, and (c) drop the pretrained
    classifier when ``num_classes`` differs — confirm against the installed
    version.
    """

def load_custom_pretrained(
    model: torch.nn.Module,
    cfg: Optional[Dict[str, Any]] = None,
    load_fn: Optional[Callable] = None,
    progress: bool = False,
    check_hash: bool = False
) -> None:
    """
    Load pretrained weights that need a custom loading routine.

    Args:
        model: Model to load weights into.
        cfg: Configuration dictionary describing the weights.
        load_fn: Custom loading function; None presumably falls back to a
            loader provided by the model — TODO confirm.
        progress: Show download progress.
        check_hash: Verify the downloaded file's hash.
    """

def build_model_with_cfg(
        model_cls: Callable, variant: str, pretrained: bool,
        pretrained_cfg: Dict[str, Any], model_cfg: Dict[str, Any], feature_cfg: Dict[str, Any],
        **kwargs,
) -> torch.nn.Module:
    """
    Construct a model variant from its configs, optionally with pretrained weights.

    Args:
        model_cls: Constructor for the model class.
        variant: Name of the model variant being built.
        pretrained: Whether to load pretrained weights.
        pretrained_cfg: Config describing the pretrained weights.
        model_cfg: Architecture config passed to the model.
        feature_cfg: Config for feature-extraction (``features_only``) mode.
        **kwargs: Extra keyword arguments for the model.

    Returns:
        The configured model instance.
    """

State Dictionary Utilities

State Dict Manipulation

def clean_state_dict(
    state_dict: Dict[str, Any],
    model: Optional[torch.nn.Module] = None
) -> Dict[str, Any]:
    """
    Clean a state dictionary by removing unwanted keys.

    Args:
        state_dict: State dictionary to clean.
        model: Optional model to match keys against.

    Returns:
        Cleaned state dictionary.

    NOTE(review): upstream timm's ``clean_state_dict`` is believed to take only
    ``state_dict`` and to strip the ``module.`` prefix added by
    (Distributed)DataParallel — confirm against the installed version.
    """

def load_state_dict(
    checkpoint_path: str,
    use_ema: bool = True,
    device: Union[str, torch.device] = 'cpu'
) -> Dict[str, Any]:
    """
    Load a state dictionary from a checkpoint file.

    Args:
        checkpoint_path: Path to the checkpoint file.
        use_ema: Prefer the EMA weights stored in the checkpoint, if present.
        device: Device (or device string such as ``'cpu'``) to map the loaded
            tensors to.

    Returns:
        Loaded state dictionary.
    """

def load_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    use_ema: bool = False,
    device: Union[str, torch.device] = 'cpu',
    strict: bool = True
) -> None:
    """
    Load a checkpoint into a model (in place).

    Args:
        model: Model to load the checkpoint into.
        checkpoint_path: Path to the checkpoint file.
        use_ema: Use the EMA weights if available. NOTE(review): this default
            (False) differs from ``load_state_dict``'s (True); upstream timm's
            ``load_checkpoint`` is believed to default to True — confirm.
        device: Device (or device string such as ``'cpu'``) used for loading.
        strict: Strict loading mode (all keys must match).
    """

def remap_state_dict(state_dict: Dict[str, Any], remap_dict: Dict[str, str]) -> Dict[str, Any]:
    """
    Rename the keys of a state dict according to a mapping.

    Args:
        state_dict: The state dict to rename.
        remap_dict: Mapping from existing key to replacement key.

    Returns:
        A state dict with renamed keys.

    NOTE(review): upstream timm's ``remap_state_dict`` is believed to be
    ``remap_state_dict(state_dict, model, allow_reshape=True)``, remapping
    keys by position onto the model's own state dict — confirm.
    """

def resume_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
    loss_scaler: Optional[Any] = None,
    log_info: bool = True
) -> Optional[int]:
    """
    Resume training state from a checkpoint.

    Restores the model weights and, when given, the optimizer and loss-scaler
    state saved in the checkpoint.

    Args:
        model: Model to restore.
        checkpoint_path: Path to the checkpoint.
        optimizer: Optimizer whose state should be restored.
        loss_scaler: AMP loss scaler whose state should be restored.
        log_info: Log resume information.

    Returns:
        The epoch to resume from, or None when the checkpoint stores no epoch.
        (Previously documented as a dict; upstream timm is believed to return
        the resume epoch — its own ``train.py`` uses
        ``resume_epoch = resume_checkpoint(...)``. Confirm against the
        installed version.)
    """

Usage Examples

Basic Feature Extraction

import torch

import timm
from timm.models import create_feature_extractor

# Create a model. eval() makes BatchNorm/Dropout behave as at inference time,
# so BatchNorm running statistics are not updated by the forward pass below.
model = timm.create_model('resnet50', pretrained=True).eval()

# Create a feature extractor for specific layers. Keys are graph node names
# (a module name such as 'layer1' selects that module's output); values are
# the keys of the returned dict.
feature_extractor = create_feature_extractor(
    model,
    return_nodes={
        'layer1': 'feat1',
        'layer2': 'feat2',
        'layer3': 'feat3',
        'layer4': 'feat4',
    },
)

# Extract features without building an autograd graph
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = feature_extractor(x)
print(f"Feature shapes: {[(k, v.shape) for k, v in features.items()]}")

Wrapper-Based Feature Extraction (features_only / FeatureListNet)

import torch

import timm
from timm.models import FeatureListNet

x = torch.randn(1, 3, 224, 224)

# Option 1: create a model that returns features as a list directly
model = timm.create_model('efficientnet_b0', pretrained=True, features_only=True).eval()
with torch.no_grad():
    features = model(x)
print(f"efficientnet_b0 feature maps: {len(features)}")

# Option 2: wrap an existing model (timm expects it to define `feature_info`)
base_model = timm.create_model('resnet34', pretrained=True)
feature_model = FeatureListNet(base_model, out_indices=[1, 2, 3, 4]).eval()

# Extract features
with torch.no_grad():
    features = feature_model(x)
print(f"Number of feature maps: {len(features)}")
for i, feat in enumerate(features):
    print(f"Feature {i}: {feat.shape}")

Model Analysis

import timm
from timm.models import get_graph_node_names, model_parameters

# Analyze model structure (pretrained weights are not needed for this)
model = timm.create_model('vit_base_patch16_224', pretrained=False)

# Get available nodes for feature extraction. The two lists come from tracing
# the model in train mode and in eval mode respectively (they can differ).
train_nodes, eval_nodes = get_graph_node_names(model)
print(f"Available nodes (eval mode): {len(eval_nodes)}")
print(f"Sample nodes: {eval_nodes[:10]}")

# Count parameters
total_params = sum(p.numel() for p in model_parameters(model))
print(f"Total parameters: {total_params:,}")

# Count parameters excluding head
body_params = sum(p.numel() for p in model_parameters(model, exclude_head=True))
print(f"Body parameters: {body_params:,}")

Model Adaptation

import torch

import timm
from timm.models import adapt_input_conv, load_checkpoint, resume_checkpoint

# Different input channels (e.g. grayscale): let create_model adapt the
# pretrained stem convolution for you.
model = timm.create_model('resnet50', pretrained=True, in_chans=1)

# adapt_input_conv itself works on a conv *weight tensor*, not on a model
rgb_model = timm.create_model('resnet50', pretrained=True)
gray_weight = adapt_input_conv(1, rgb_model.conv1.weight)
print(f"Adapted stem weight: {tuple(gray_weight.shape)}")  # (64, 1, 7, 7)

# Load a custom checkpoint saved for this same architecture / in_chans
load_checkpoint(model, 'path/to/checkpoint.pth')

# Resume training: restores model (and optimizer state, if saved) and returns
# the epoch to resume from, or None when the checkpoint has no epoch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
resume_epoch = resume_checkpoint(
    model,
    'path/to/checkpoint.pth',
    optimizer=optimizer,
    log_info=True,
)
start_epoch = resume_epoch if resume_epoch is not None else 0

Advanced Feature Configuration

import timm

# Create a model with a specific feature configuration. features_only=True
# already drops the pooling/classifier head, so global_pool / num_classes are
# not needed here.
model = timm.create_model(
    'resnet50',
    pretrained=True,
    features_only=True,
    out_indices=[1, 2, 3, 4],  # Which stages to output
    output_stride=16,          # Overall output stride (later stages dilate instead of downsampling)
)

# Get feature info
for info in model.feature_info.get_dicts():
    print(f"Layer: {info['module']}, Channels: {info['num_chs']}, Reduction: {info['reduction']}")

# Convenience accessors for the same information
print(f"Channels: {model.feature_info.channels()}")
print(f"Reductions: {model.feature_info.reduction()}")

Types

from typing import Optional, Union, List, Dict, Callable, Any, Tuple, Iterator
import torch

# Feature extraction types
FeatureDict = Dict[str, torch.Tensor]  # named feature maps (dict-style extractors)
FeatureList = List[torch.Tensor]  # ordered feature maps (list-style extractors)
NodeSpec = Union[Dict[str, str], List[str]]  # format of a `return_nodes` argument

# Model analysis types
ParameterIterator = Iterator[torch.nn.Parameter]
ModuleDict = Dict[str, torch.nn.Module]  # plain dict alias, not torch.nn.ModuleDict
ParameterDict = Dict[str, torch.nn.Parameter]  # plain dict alias, not torch.nn.ParameterDict

# State dict types
StateDict = Dict[str, Any]
RemapDict = Dict[str, str]

# Hook types
# PyTorch forward hooks are called as hook(module, args, output), where `args`
# is the tuple of positional inputs; a hook may return None or a replacement
# output.
HookFunction = Callable[[torch.nn.Module, Tuple[torch.Tensor, ...], torch.Tensor], Optional[torch.Tensor]]
# NOTE(review): predicate over (name, parameter); which API consumes it is not
# shown here — confirm before relying on it.
FilterFunction = Callable[[str, torch.nn.Parameter], bool]

Install with Tessl CLI

npx tessl i tessl/pypi-timm

docs

data.md

features.md

index.md

layers.md

models.md

training.md

utils.md

tile.json