PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks
Comprehensive functionality for discovering, creating, and configuring computer vision models from TIMM's extensive collection of 1000+ pretrained models across 90+ architectures.
Create model instances with extensive configuration options, including pretrained weights, custom number of classes, and architectural modifications.
def create_model(
model_name: str,
pretrained: bool = False,
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
checkpoint_path: Optional[Union[str, Path]] = None,
cache_dir: Optional[Union[str, Path]] = None,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
**kwargs: Any
) -> torch.nn.Module:
"""
Create a model instance by name.
Args:
model_name: Name of model architecture to instantiate
pretrained: Load pretrained weights if True
pretrained_cfg: Pretrained configuration override (cfg name, dict, or PretrainedCfg)
pretrained_cfg_overlay: Dictionary of values overlaid on the pretrained config
checkpoint_path: Path to a checkpoint to load instead of pretrained weights
cache_dir: Cache directory for downloaded pretrained weights
scriptable: Set layer config so model is jit scriptable
exportable: Set layer config so model is traceable/ONNX exportable
no_jit: Disable jit related set/reset of layer config
**kwargs: Model-specific arguments forwarded to the model entrypoint; common
ones include num_classes (number of output classes, default 1000),
in_chans (number of input image channels, default 3), and
global_pool (global pooling type override)
Returns:
Instantiated model
"""

import timm
# Basic model creation with pretrained weights
model = timm.create_model('resnet50', pretrained=True)
# Custom number of classes for fine-tuning (replaces the classification head)
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
# Backbone for feature extraction; features_only=True yields intermediate feature maps
feature_model = timm.create_model('vit_base_patch16_224', pretrained=True, features_only=True)
# Model with layer config set for jit scripting and ONNX export
export_model = timm.create_model('resnet18', pretrained=True, scriptable=True, exportable=True)
# Load from a custom checkpoint instead of pretrained weights
model = timm.create_model('resnet50', checkpoint_path='/path/to/checkpoint.pth')

Functions to explore and filter the available model architectures and pretrained weights.
def list_models(
filter: str = '',
module: str = '',
pretrained: bool = False,
exclude_filters: str = '',
name_matches_cfg: bool = False,
include_tags: bool = True
) -> list[str]:
"""
List available model names, optionally filtered.
Args:
filter: Wildcard filter string to limit returned model names
module: Restrict results to a specific module/architecture family
pretrained: If True, include only models with pretrained weights
exclude_filters: Wildcard pattern(s) excluding matching models from the results
name_matches_cfg: If True, include only models whose name matches their default config
include_tags: If True, include pretrained weight tags in the returned names
Returns:
List of model names matching the criteria
"""
def list_pretrained(filter: str = '') -> list[str]:
"""
List model names that have pretrained weights available.
Args:
filter: Wildcard filter string to limit returned model names
Returns:
List of model names with pretrained weights
"""
def list_modules() -> list[str]:
"""
List available model modules/architecture families.
Returns:
List of module names (usable as the `module` argument of list_models())
"""

# List all models
all_models = timm.list_models()
# Filter models by wildcard pattern
resnet_models = timm.list_models('*resnet*')
vit_models = timm.list_models('vit_*')
# Only models that ship pretrained weights
pretrained_models = timm.list_models(pretrained=True)
# Pretrained variants of a specific architecture family
efficientnet_pretrained = timm.list_pretrained('efficientnet*')
# Available model families/modules
architectures = timm.list_modules()

Utilities to validate model names and check availability of pretrained weights.
def is_model(model_name: str) -> bool:
"""
Check whether a model name is registered and available.
Args:
model_name: Model name to check
Returns:
True if the model exists, False otherwise
"""
def is_model_pretrained(model_name: str) -> bool:
"""
Check whether a model has pretrained weights available.
Args:
model_name: Model name to check
Returns:
True if pretrained weights exist for the model, False otherwise
"""
def model_entrypoint(model_name: str) -> Callable:
"""
Get the registered entrypoint (creation function) for a model.
Args:
model_name: Name of model
Returns:
Model creation function
"""

Access and retrieve model configuration and metadata.
def get_pretrained_cfg(model_name: str) -> dict:
"""
Get the pretrained configuration for a model.
Args:
model_name: Name of model
Returns:
Dictionary of model configuration, including:
- input_size: Expected input dimensions
- mean: Normalization mean values
- std: Normalization standard deviation values
- num_classes: Number of output classes
- pool_size: Global pooling output size
- crop_pct: Center crop percentage
- interpolation: Resize interpolation method
- first_conv: Name of first convolutional layer
- classifier: Name of classifier layer
"""
def get_pretrained_cfg_value(model_name: str, cfg_key: str):
"""
Get a single value from a model's pretrained configuration.
Args:
model_name: Name of model
cfg_key: Configuration key to retrieve (e.g. 'input_size', 'crop_pct')
Returns:
Configuration value for the specified key
"""

# Get complete model configuration
cfg = timm.get_pretrained_cfg('resnet50')
print(f"Input size: {cfg['input_size']}")
print(f"Mean: {cfg['mean']}")
print(f"Std: {cfg['std']}")
# Fetch individual configuration values without reading the whole config
input_size = timm.get_pretrained_cfg_value('efficientnet_b0', 'input_size')
crop_pct = timm.get_pretrained_cfg_value('vit_base_patch16_224', 'crop_pct')
# Guard model creation on the architecture name being registered
if timm.is_model('my_custom_model'):
model = timm.create_model('my_custom_model')
# Guard pretrained loading on availability of weights
if timm.is_model_pretrained('resnet101'):
model = timm.create_model('resnet101', pretrained=True)

Advanced patterns for model customization and creation.
def create_model_from_pretrained(
model_name: str,
pretrained_cfg: dict = None,
**model_kwargs
) -> torch.nn.Module:
"""
Create a model using a specific pretrained configuration.
Args:
model_name: Name of model to create
pretrained_cfg: Custom pretrained configuration overriding the default
**model_kwargs: Additional model-specific arguments
Returns:
Configured model instance
"""

def register_model(fn: Callable = None, *, name: str = None) -> Callable:
"""
Register a new model architecture.
Args:
fn: Model creation function
name: Optional model name override
Returns:
Decorated function
"""

# Register custom model
@timm.register_model
def my_custom_resnet(pretrained=False, **kwargs):
# Build the custom architecture (MyCustomResNet defined elsewhere)
model = MyCustomResNet(**kwargs)
if pretrained:
# Load custom pretrained weights here
pass
return model
# The registered name is now usable with timm.create_model
custom_model = timm.create_model('my_custom_resnet', pretrained=True)

TIMM provides seamless integration with Hugging Face Hub for loading models and configurations.
def load_model_config_from_hf(model_id: str) -> dict:
"""
Load a model configuration from the Hugging Face Hub.
Args:
model_id: Hugging Face model identifier (e.g. 'org/model-name')
Returns:
Model configuration dictionary
"""
def load_state_dict_from_hf(model_id: str) -> dict:
"""
Load model weights (state dict) from the Hugging Face Hub.
Args:
model_id: Hugging Face model identifier (e.g. 'org/model-name')
Returns:
Model state dictionary
"""

# Load model from Hugging Face Hub using hf-hub: prefix
model = timm.create_model('hf-hub:microsoft/resnet-50', pretrained=True)
# Load a model from a local directory using the local-dir: prefix
model = timm.create_model('local-dir:/path/to/model/folder', pretrained=True)
# Pin a specific Hub revision/branch with the @revision suffix
model = timm.create_model('hf-hub:microsoft/resnet-50@main', pretrained=True)

TIMM includes models from the following major categories:
TIMM supports variable aspect ratio and resolution training/inference through NaFlexViT integration.
# Enable NaFlexViT variable aspect/resolution mode for supported models
model = timm.create_model('vit_base_patch16_224', pretrained=True, use_naflex=True)
# Models with ROPE position embeddings can also be loaded in NaFlexViT mode
model = timm.create_model('eva_large_patch14_196', pretrained=True, use_naflex=True)

Extract intermediate features from models during forward pass.
# forward_intermediates() returns features from selected intermediate stages
model = timm.create_model('resnet50', pretrained=True)
features = model.forward_intermediates(x, indices=[1, 2, 3, 4])

from typing import Optional, Union, List, Dict, Callable, Any
import torch
# Model configuration types (plain dicts of config fields)
PretrainedCfg = Dict[str, Any]
ModelCfg = Dict[str, Any]
# Signature shared by registered model creation functions
ModelEntrypoint = Callable[..., torch.nn.Module]
# Filter types for model listing (single pattern or list of patterns)
ModelFilter = Union[str, List[str]]

Install with Tessl CLI
npx tessl i tessl/pypi-timm