PyTorch Image Models (timm): a library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks.
Advanced functionality for extracting features from models, analyzing model architecture, and manipulating pretrained models for custom use cases.
Create feature extractors that can extract intermediate representations from any layer of a model.
def create_feature_extractor(
    model: torch.nn.Module,
    return_nodes: Union[Dict[str, str], List[str]],
    train_return_nodes: Optional[Union[Dict[str, str], List[str]]] = None,
    suppress_diff_warnings: bool = False,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs
):
    """
    Create a feature extractor from a model using FX graph tracing.

    Args:
        model: Source model to extract features from.
        return_nodes: Nodes to return features from. Either a dict mapping
            graph node names to output names, or a list of node names.
        train_return_nodes: Different nodes to return in training mode.
            ``None`` (the default) means no separate training-mode nodes
            are given.
        suppress_diff_warnings: Suppress warnings about model differences.
        tracer_kwargs: Additional keyword arguments for the FX tracer, or
            ``None``.
        **kwargs: Additional arguments.

    Returns:
        Feature extractor model that returns the specified intermediate
        features. The usage example iterates its output with ``.items()``,
        i.e. a mapping from output name to tensor.
    """
def get_graph_node_names(
    model: torch.nn.Module,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warnings: bool = False
) -> Tuple[List[str], List[str]]:
    """
    Get node names from a model's FX graph for feature extraction.

    Args:
        model: Model to analyze.
        tracer_kwargs: Additional tracer arguments, or ``None``.
        suppress_diff_warnings: Suppress model difference warnings.

    Returns:
        Tuple of two lists of node names, ``(train_nodes, eval_nodes)``.
        These are the extraction points available when the model is traced
        in training mode and in eval mode respectively. Both lists hold
        node names; neither holds node types.
    """


class FeatureInfo:
    """
    Information about extracted features.

    Args:
        feature_info: List of feature information dictionaries. The usage
            examples read the keys ``'module'``, ``'num_chs'`` and
            ``'reduction'`` from each entry.
        out_indices: Output indices for features.
    """

    def __init__(
        self,
        feature_info: List[Dict[str, Any]],
        out_indices: List[int]
    ): ...

    def get_dicts(self, keys: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Get feature info as a list of dictionaries.

        ``keys`` presumably limits which entries each dict carries; ``None``
        means no restriction (confirm against the implementation).
        """

    def channels(self, idx: Optional[int] = None) -> Union[List[int], int]:
        """Get feature channels: a list, or a single int (presumably when ``idx`` is given)."""

    def reduction(self, idx: Optional[int] = None) -> Union[List[int], int]:
        """Get feature reduction factors: a list, or a single int (presumably when ``idx`` is given)."""
class FeatureHooks:
    """
    Feature extraction using forward hooks.

    Args:
        hooks: List of hook functions.
        named_modules: Dictionary of named modules.
        out_map: Optional output mapping for feature names. Entries may be
            ints or strings. The other wrappers in this API annotate their
            ``out_map`` with string entries.
    """

    def __init__(
        self,
        hooks: List[Callable],
        named_modules: Dict[str, torch.nn.Module],
        out_map: Optional[List[Union[int, str]]] = None
    ): ...

    def get_output(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Get hooked features from forward pass."""
class FeatureHookNet(torch.nn.Module):
    """
    Wrapper that uses hooks to extract features during the forward pass.

    Args:
        model: Base model to wrap.
        out_indices: Indices of layers to extract features from.
        out_map: Optional mapping of output names.
        return_interm: Return intermediate features.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        return_interm: bool = False,
        **kwargs
    ): ...
class FeatureListNet(torch.nn.Module):
    """
    Wrap a model so that its intermediate features come back as a list.

    Args:
        model: The base model being wrapped.
        out_indices: Positions of the layers whose features are returned.
        **kwargs: Extra arguments.
    """

    def __init__(self, model: torch.nn.Module, out_indices: List[int], **kwargs):
        pass
class FeatureDictNet(torch.nn.Module):
    """
    Wrapper that returns features as a dictionary.

    Args:
        model: Base model to wrap.
        out_indices: Indices of layers to extract features from.
        out_map: Optional names for output features.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        **kwargs
    ): ...


class FeatureGraphNet(torch.nn.Module):
    """
    FX-based feature extraction network.

    Args:
        model: Base model.
        out_indices: Output layer indices.
        out_map: Optional feature name mapping.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        out_indices: List[int],
        out_map: Optional[List[str]] = None,
        **kwargs
    ): ...
class GraphExtractNet(torch.nn.Module):
    """
    Graph-based feature extraction using FX.

    Args:
        model: Source model.
        return_nodes: Nodes to extract features from. Either a dict mapping
            node names to output names, or a list of node names. This matches
            the node specification accepted by ``create_feature_extractor``.
        **kwargs: Additional arguments.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        return_nodes: Union[Dict[str, str], List[str]],
        **kwargs
    ): ...


def model_parameters(
    model: torch.nn.Module,
    exclude_head: bool = False,
    recurse: bool = True
) -> Iterator[torch.nn.Parameter]:
    """
    Get model parameters with filtering options.

    Args:
        model: Model to analyze.
        exclude_head: Exclude classifier/head parameters.
        recurse: Recurse into submodules.

    Returns:
        Iterator over model parameters.
    """
def named_apply(
    fn: Callable, module: torch.nn.Module, name: str = '',
    depth_first: bool = True, include_root: bool = False,
) -> torch.nn.Module:
    """
    Recursively call ``fn`` on a module's named submodules.

    Args:
        fn: Callable invoked for every visited module.
        module: Module at the top of the traversal.
        name: Name of the module currently being visited.
        depth_first: Whether the traversal is depth-first.
        include_root: Whether the root module itself is visited.

    Returns:
        The (possibly modified) module.
    """
def named_modules(
    module: torch.nn.Module,
    memo: Optional[set] = None,
    prefix: str = '',
    remove_duplicate: bool = True
) -> Iterator[Tuple[str, torch.nn.Module]]:
    """
    Get named modules with filtering.

    Args:
        module: Root module.
        memo: Set for tracking already-seen modules (duplicates), or ``None``.
        prefix: Name prefix.
        remove_duplicate: Remove duplicate modules.

    Returns:
        Iterator of ``(name, module)`` pairs.
    """
def group_modules(
    module: torch.nn.Module, group_matcher: Callable,
    output_values: bool = False, reverse: bool = False,
) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Module]]]:
    """
    Partition a module's submodules into groups using ``group_matcher``.

    Args:
        module: Module whose submodules are grouped.
        group_matcher: Callable that decides which group a module belongs to.
        output_values: When true, groups hold module objects instead of names.
        reverse: Reverse the grouping order.

    Returns:
        Dict from group ID to the module names (or module objects) in it.
    """
def group_parameters(
    module: torch.nn.Module, group_matcher: Callable,
    output_values: bool = False, reverse: bool = False,
) -> Union[Dict[int, List[str]], Dict[int, List[torch.nn.Parameter]]]:
    """
    Partition a module's parameters into groups using ``group_matcher``.

    Args:
        module: Module whose parameters are grouped.
        group_matcher: Callable that decides which group a parameter belongs to.
        output_values: When true, groups hold parameter objects instead of names.
        reverse: Reverse the grouping order.

    Returns:
        Dict from group ID to the parameter names (or parameter objects) in it.
    """
def checkpoint_seq(
    functions: List[Callable],
    segments: int = 1,
    input: Optional[torch.Tensor] = None,  # NOTE: public name shadows builtin `input`
    **kwargs
) -> torch.Tensor:
    """
    Apply gradient checkpointing to a sequence of functions.

    Args:
        functions: List of functions to apply.
        segments: Number of checkpoint segments.
        input: Input tensor, or ``None``.
        **kwargs: Additional arguments.

    Returns:
        Output tensor with gradient checkpointing applied.
    """


def adapt_input_conv(
    model: torch.nn.Module,
    in_chans: int,
    conv_layer: Optional[str] = None
) -> torch.nn.Module:
    """
    Adapt a model's input convolution for a different channel count.

    NOTE(review): upstream timm exposes ``adapt_input_conv(in_chans,
    conv_weight)``. That version adapts a conv *weight tensor* rather than
    a model. Confirm this signature against the installed timm version.

    Args:
        model: Model to adapt.
        in_chans: New number of input channels.
        conv_layer: Name of the convolution layer to adapt, or ``None``.

    Returns:
        Model with adapted input convolution.
    """
def load_pretrained(
    model: torch.nn.Module,
    cfg: Optional[Dict[str, Any]] = None,
    num_classes: int = 1000,
    in_chans: int = 3,
    filter_fn: Optional[Callable] = None,
    strict: bool = True,
    progress: bool = False
) -> None:
    """
    Load pretrained weights into a model.

    Args:
        model: Model to load weights into.
        cfg: Pretrained configuration, or ``None``.
        num_classes: Number of output classes.
        in_chans: Number of input channels.
        filter_fn: Function to filter state dict keys, or ``None``.
        strict: Strict loading mode.
        progress: Show download progress.
    """
def load_custom_pretrained(
    model: torch.nn.Module,
    cfg: Optional[Dict[str, Any]] = None,
    load_fn: Optional[Callable] = None,
    progress: bool = False,
    check_hash: bool = False
) -> None:
    """
    Load custom pretrained weights.

    Args:
        model: Model to load weights into.
        cfg: Configuration dictionary, or ``None``.
        load_fn: Custom loading function, or ``None``.
        progress: Show progress.
        check_hash: Verify file hash.
    """
def build_model_with_cfg(
    model_cls: Callable,
    variant: str,
    pretrained: bool,
    pretrained_cfg: Dict[str, Any],
    model_cfg: Dict[str, Any],
    feature_cfg: Dict[str, Any],
    **kwargs
) -> torch.nn.Module:
    """
    Build a model from its configuration.

    Args:
        model_cls: Model class constructor.
        variant: Model variant name.
        pretrained: Load pretrained weights.
        pretrained_cfg: Pretrained configuration.
        model_cfg: Model configuration.
        feature_cfg: Feature extraction configuration.
        **kwargs: Additional model arguments.

    Returns:
        Configured model instance.
    """


def clean_state_dict(
    state_dict: Dict[str, Any],
    model: Optional[torch.nn.Module] = None
) -> Dict[str, Any]:
    """
    Clean a state dictionary by removing unwanted keys.

    Args:
        state_dict: State dictionary to clean.
        model: Model to match against, or ``None``.

    Returns:
        Cleaned state dictionary.
    """
def load_state_dict(
    checkpoint_path: str,
    use_ema: bool = True,
    device: Union[str, torch.device] = 'cpu'
) -> Dict[str, Any]:
    """
    Load a state dictionary from a checkpoint file.

    Args:
        checkpoint_path: Path to checkpoint file.
        use_ema: Use EMA weights if available.
        device: Device to load tensors on, as a device string such as
            ``'cpu'`` (the default) or a ``torch.device``.

    Returns:
        Loaded state dictionary.
    """
def load_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    use_ema: bool = False,  # NOTE(review): load_state_dict defaults use_ema=True; confirm intended
    device: Union[str, torch.device] = 'cpu',
    strict: bool = True
) -> None:
    """
    Load a checkpoint into a model.

    Args:
        model: Model to load checkpoint into.
        checkpoint_path: Path to checkpoint file.
        use_ema: Use EMA weights if available.
        device: Device for loading, as a device string such as ``'cpu'``
            (the default) or a ``torch.device``.
        strict: Strict loading mode.
    """
def remap_state_dict(state_dict: Dict[str, Any], remap_dict: Dict[str, str]) -> Dict[str, Any]:
    """
    Rename the keys of a state dictionary according to a mapping.

    Args:
        state_dict: The state dictionary to remap.
        remap_dict: Mapping of old key -> new key.

    Returns:
        The state dictionary with remapped keys.
    """
def resume_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
    loss_scaler=None,
    log_info: bool = True
) -> Optional[int]:
    """
    Resume training from a checkpoint.

    Args:
        model: Model to resume.
        checkpoint_path: Path to checkpoint.
        optimizer: Optimizer to resume, or ``None``.
        loss_scaler: Loss scaler to resume, or ``None``.
        log_info: Log resume information.

    Returns:
        The epoch to resume training from, or ``None`` when the checkpoint
        records no epoch. Upstream timm returns this value directly; it does
        not return a dict of resume information.
    """


import timm
from timm.models import create_feature_extractor

# Create a model
model = timm.create_model('resnet50', pretrained=True)

# Create feature extractor for specific layers
feature_extractor = create_feature_extractor(
    model,
    return_nodes={
        'layer1': 'feat1',
        'layer2': 'feat2',
        'layer3': 'feat3',
        'layer4': 'feat4',
    },
)

# Extract features
import torch

x = torch.randn(1, 3, 224, 224)
features = feature_extractor(x)
print(f"Feature shapes: {[(k, v.shape) for k, v in features.items()]}")


from timm.models import FeatureListNet

# Create model that returns features as list
model = timm.create_model('efficientnet_b0', pretrained=True, features_only=True)

# Or wrap existing model
base_model = timm.create_model('resnet34', pretrained=True)
feature_model = FeatureListNet(base_model, out_indices=[1, 2, 3, 4])

# Extract features (reuses the input tensor `x` defined in the example above)
features = feature_model(x)
print(f"Number of feature maps: {len(features)}")
for i, feat in enumerate(features):
    print(f"Feature {i}: {feat.shape}")


from timm.models import get_graph_node_names, model_parameters

# Analyze model structure
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Get available nodes for feature extraction. The result is two lists of
# node names: one for the graph traced in train mode, one for eval mode.
train_nodes, eval_nodes = get_graph_node_names(model)
print(f"Available nodes: {len(train_nodes)}")
print(f"Sample nodes: {train_nodes[:10]}")

# Count parameters
total_params = sum(p.numel() for p in model_parameters(model))
print(f"Total parameters: {total_params:,}")

# Count parameters excluding head
body_params = sum(p.numel() for p in model_parameters(model, exclude_head=True))
print(f"Body parameters: {body_params:,}")


from timm.models import load_checkpoint, resume_checkpoint

# Adapt model for different input channels (e.g., grayscale). Passing
# `in_chans` to `create_model` adapts the pretrained input convolution.
# timm's low-level `adapt_input_conv(in_chans, conv_weight)` helper works
# on a conv weight tensor, so it cannot be called with a whole model.
model = timm.create_model('resnet50', pretrained=True, in_chans=1)

# Load custom checkpoint
load_checkpoint(model, 'path/to/checkpoint.pth')

# Resume training. An optimizer must exist before its state can be restored;
# SGD here is only a placeholder for whatever optimizer the run uses.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
resume_epoch = resume_checkpoint(
    model,
    'path/to/checkpoint.pth',
    optimizer=optimizer,
    log_info=True,
)
# resume_checkpoint returns the epoch (or None), not a dict
start_epoch = resume_epoch if resume_epoch is not None else 0


# Create model with specific feature configuration
model = timm.create_model(
    'resnet50',
    pretrained=True,
    features_only=True,
    out_indices=[1, 2, 3, 4],  # Which stages to output
    output_stride=16,  # Overall output stride
    global_pool='',  # Disable global pooling
    num_classes=0,  # Remove classifier
)

# Get feature info
feature_info = model.feature_info.get_dicts()
for info in feature_info:
    print(f"Layer: {info['module']}, Channels: {info['num_chs']}, Reduction: {info['reduction']}")


from typing import Optional, Union, List, Dict, Callable, Any, Tuple, Iterator
import torch

# Feature extraction types
FeatureDict = Dict[str, torch.Tensor]
FeatureList = List[torch.Tensor]
NodeSpec = Union[Dict[str, str], List[str]]

# Model analysis types
ParameterIterator = Iterator[torch.nn.Parameter]
ModuleDict = Dict[str, torch.nn.Module]
ParameterDict = Dict[str, torch.nn.Parameter]

# State dict types
StateDict = Dict[str, Any]
RemapDict = Dict[str, str]

# Hook types. PyTorch forward hooks are called as hook(module, args, output),
# where `args` is the tuple of positional inputs. A non-None return value
# replaces the module's output.
HookFunction = Callable[
    [torch.nn.Module, Tuple[torch.Tensor, ...], torch.Tensor],
    Optional[torch.Tensor],
]
FilterFunction = Callable[[str, torch.nn.Parameter], bool]

# Install with Tessl CLI
npx tessl i tessl/pypi-timm