
tessl/pypi-timm

PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks


Model Creation and Management

Comprehensive functionality for discovering, creating, and configuring computer vision models from TIMM's extensive collection of 1000+ pretrained models across 90+ architectures.

Capabilities

Model Creation

Create model instances with extensive configuration options, including pretrained weights, custom number of classes, and architectural modifications.

def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: Optional[Union[str, Path]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    **kwargs: Any
) -> torch.nn.Module:
    """
    Create a model instance.

    Args:
        model_name: Name of model to instantiate
        pretrained: Load pretrained weights if True
        pretrained_cfg: Pretrained configuration override (dict or cfg name)
        pretrained_cfg_overlay: Dictionary of config overrides
        checkpoint_path: Path to load checkpoint from instead of pretrained weights
        cache_dir: Cache directory for downloaded pretrained weights
        scriptable: Set layer config so model is jit scriptable
        exportable: Set layer config so model is traceable/ONNX exportable
        no_jit: Disable jit related set/reset of layer config
        **kwargs: Model-specific arguments passed through to the model, e.g.:
            num_classes: Number of output classes (default: 1000)
            in_chans: Number of input image channels (default: 3)
            global_pool: Global pooling type override

    Returns:
        Instantiated model
    """

Usage Examples

import timm

# Basic model creation
model = timm.create_model('resnet50', pretrained=True)

# Custom number of classes for fine-tuning
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)

# Model for feature extraction
feature_model = timm.create_model('vit_base_patch16_224', pretrained=True, features_only=True)

# Model optimized for export
export_model = timm.create_model('resnet18', pretrained=True, scriptable=True, exportable=True)

# Load from custom checkpoint
model = timm.create_model('resnet50', checkpoint_path='/path/to/checkpoint.pth')

Model Discovery

Functions to explore and filter the available model architectures and pretrained weights.

def list_models(
    filter: str = '',
    module: str = '',
    pretrained: bool = False,
    exclude_filters: str = '',
    name_matches_cfg: bool = False,
    include_tags: bool = True
) -> list[str]:
    """
    List available models.

    Args:
        filter: Wildcard filter string to limit model names
        module: Specific module/architecture to limit results
        pretrained: Only models with pretrained weights if True
        exclude_filters: Exclude models matching these patterns
        name_matches_cfg: Only models where name matches config
        include_tags: Include model tags in results

    Returns:
        List of model names matching criteria
    """

def list_pretrained(filter: str = '') -> list[str]:
    """
    List models with pretrained weights available.

    Args:
        filter: Wildcard filter for model names

    Returns:
        List of model names with pretrained weights
    """

def list_modules() -> list[str]:
    """
    List available model modules/architectures.

    Returns:
        List of module names
    """

Usage Examples

# List all models
all_models = timm.list_models()

# Filter models by architecture
resnet_models = timm.list_models('*resnet*')
vit_models = timm.list_models('vit_*')

# Only models with pretrained weights
pretrained_models = timm.list_models(pretrained=True)

# List specific architecture variants
efficientnet_pretrained = timm.list_pretrained('efficientnet*')

# Available model families
architectures = timm.list_modules()
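The filter and exclude_filters strings use shell-style wildcards. As a mental model (an assumption about the implementation, but it predicts the results well), matching behaves like Python's fnmatch:

```python
from fnmatch import fnmatch

# Hypothetical model names to illustrate wildcard matching
names = ['resnet50', 'resnet101', 'seresnet50', 'vit_base_patch16_224']

# '*resnet*' matches the substring anywhere in the name
print([n for n in names if fnmatch(n, '*resnet*')])  # includes seresnet50

# 'vit_*' anchors the match at the start of the name
print([n for n in names if fnmatch(n, 'vit_*')])     # only vit_base_patch16_224
```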

Model Validation

Utilities to validate model names and check availability of pretrained weights.

def is_model(model_name: str) -> bool:
    """
    Check if model name is valid and available.

    Args:
        model_name: Name to check

    Returns:
        True if model exists, False otherwise
    """

def is_model_pretrained(model_name: str) -> bool:
    """
    Check if model has pretrained weights available.

    Args:
        model_name: Model name to check

    Returns:
        True if pretrained weights exist, False otherwise
    """

def model_entrypoint(model_name: str) -> Callable:
    """
    Get the entrypoint function for a model.

    Args:
        model_name: Name of model

    Returns:
        Model creation function
    """

Model Configuration

Access and retrieve model configuration and metadata.

def get_pretrained_cfg(model_name: str) -> dict:
    """
    Get pretrained configuration for model.

    Args:
        model_name: Name of model

    Returns:
        Dictionary containing model configuration including:
        - input_size: Expected input dimensions
        - mean: Normalization mean values
        - std: Normalization standard deviation values
        - num_classes: Number of output classes
        - pool_size: Global pooling output size
        - crop_pct: Center crop percentage
        - interpolation: Resize interpolation method
        - first_conv: Name of first convolutional layer
        - classifier: Name of classifier layer
    """

def get_pretrained_cfg_value(model_name: str, cfg_key: str):
    """
    Get specific configuration value for pretrained model.

    Args:
        model_name: Name of model
        cfg_key: Configuration key to retrieve

    Returns:
        Configuration value for specified key
    """

Usage Examples

# Get complete model configuration
cfg = timm.get_pretrained_cfg('resnet50')
print(f"Input size: {cfg['input_size']}")
print(f"Mean: {cfg['mean']}")
print(f"Std: {cfg['std']}")

# Get specific configuration values
input_size = timm.get_pretrained_cfg_value('efficientnet_b0', 'input_size')
crop_pct = timm.get_pretrained_cfg_value('vit_base_patch16_224', 'crop_pct')

# Validate model availability
if timm.is_model('my_custom_model'):
    model = timm.create_model('my_custom_model')

# Check for pretrained weights
if timm.is_model_pretrained('resnet101'):
    model = timm.create_model('resnet101', pretrained=True)

Advanced Model Creation

Advanced patterns for model customization and creation.

Model Factory Functions

def create_model_from_pretrained(
    model_name: str,
    pretrained_cfg: Optional[dict] = None,
    **model_kwargs
) -> torch.nn.Module:
    """
    Create model using specific pretrained configuration.

    Args:
        model_name: Name of model to create
        pretrained_cfg: Custom pretrained configuration
        **model_kwargs: Additional model arguments

    Returns:
        Configured model instance
    """

Custom Model Registration

def register_model(fn: Callable = None, *, name: str = None) -> Callable:
    """
    Register a new model architecture.

    Args:
        fn: Model creation function
        name: Optional model name override

    Returns:
        Decorated function
    """

Usage Examples

# Register custom model
@timm.register_model
def my_custom_resnet(pretrained=False, **kwargs):
    # Custom ResNet implementation
    model = MyCustomResNet(**kwargs)
    if pretrained:
        # Load custom pretrained weights
        pass
    return model

# Use registered model
custom_model = timm.create_model('my_custom_resnet', pretrained=True)

Hugging Face Hub Integration

TIMM provides seamless integration with Hugging Face Hub for loading models and configurations.

def load_model_config_from_hf(model_id: str) -> dict:
    """
    Load model configuration from Hugging Face Hub.

    Args:
        model_id: Hugging Face model identifier

    Returns:
        Model configuration dictionary
    """

def load_state_dict_from_hf(model_id: str) -> dict:
    """
    Load model weights from Hugging Face Hub.

    Args:
        model_id: Hugging Face model identifier

    Returns:
        Model state dictionary
    """

Hub Model Loading Examples

# Load model from Hugging Face Hub using the hf-hub: prefix
# (the repo must contain timm-format weights, e.g. repos under the timm org)
model = timm.create_model('hf-hub:timm/resnet50.a1_in1k', pretrained=True)

# Load local model using the local-dir: prefix
model = timm.create_model('local-dir:/path/to/model/folder', pretrained=True)

# Load a specific model revision/branch with @
model = timm.create_model('hf-hub:timm/resnet50.a1_in1k@main', pretrained=True)

Model Architecture Categories

TIMM includes models from the following major categories:

Vision Transformers

  • ViT: Vision Transformer variants (Base, Large, Huge)
  • DeiT: Data-efficient Image Transformers
  • BEiT: Bidirectional Encoder representation from Image Transformers
  • Swin: Swin Transformer hierarchical models
  • CaiT: Class-Attention in Image Transformers
  • CrossViT: Cross-Attention Multi-Scale Vision Transformer

Convolutional Networks

  • ResNet: ResNet and ResNeXt families
  • EfficientNet: EfficientNet B0-B8 and V2 variants
  • ConvNeXt: Modern ConvNet architectures
  • RegNet: Designing Network Design Spaces
  • DenseNet: Densely Connected Convolutional Networks
  • MobileNet: MobileNetV3 and variants

Hybrid Architectures

  • ConViT: Convolutions meet Vision Transformers
  • LeViT: Vision Transformer in ConvNet's Clothing
  • CoAtNet: Convolution and Attention networks
  • MaxViT: Multi-Axis Vision Transformer

Specialized Models

  • CLIP: Vision encoders from CLIP models
  • BEiT3: Multimodal foundation models
  • EVA: Enhanced Vision Transformer
  • InternViT: Large-scale vision foundation models

Advanced Features

NaFlexViT (Native Flexible Vision Transformers)

TIMM supports variable aspect ratio and resolution training/inference through NaFlexViT integration.

# Enable NaFlexViT for supported models
model = timm.create_model('vit_base_patch16_224', pretrained=True, use_naflex=True)

# Models with ROPE support can be loaded in NaFlexViT mode
model = timm.create_model('eva_large_patch14_196', pretrained=True, use_naflex=True)

Forward Intermediates API

Extract intermediate features from models during forward pass.

import torch

model = timm.create_model('resnet50', pretrained=True)
x = torch.randn(1, 3, 224, 224)

# Returns the final features plus the requested intermediate feature maps
final, intermediates = model.forward_intermediates(x, indices=[1, 2, 3, 4])

# Or return only the intermediate feature maps
intermediates = model.forward_intermediates(x, indices=[1, 2, 3, 4], intermediates_only=True)

Types

from typing import Optional, Union, List, Dict, Callable, Any
import torch

# Model configuration types
PretrainedCfg = Dict[str, Any]
ModelCfg = Dict[str, Any]

# Model creation function signature
ModelEntrypoint = Callable[..., torch.nn.Module]

# Filter types for model listing
ModelFilter = Union[str, List[str]]

Install with Tessl CLI

npx tessl i tessl/pypi-timm
