CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-timm

PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks

Overview
Eval results
Files

docs/utils.md

Utilities and Helpers

General utilities for distributed training, model management, checkpointing, logging, and other supporting functionality for production computer vision workflows.

Capabilities

Model Utilities

Functions for model management, parameter manipulation, and model state operations.

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """
    Strip wrapper layers and return the underlying base module.

    Containers such as DistributedDataParallel or EMA holders keep the
    real network as an attribute; this helper peels those off.

    Args:
        model: Possibly wrapped model instance.

    Returns:
        The innermost, unwrapped module.
    """

def get_state_dict(
    model: torch.nn.Module,
    unwrap_fn: Callable = unwrap_model
) -> Dict[str, Any]:
    """
    Return a model's state dictionary after first unwrapping it.

    Args:
        model: Source model, possibly wrapped (DDP/EMA/etc.).
        unwrap_fn: Callable applied to ``model`` to obtain the bare module.

    Returns:
        State dictionary of the unwrapped module.
    """

def freeze(model: torch.nn.Module) -> None:
    """
    Disable gradient computation for every parameter of ``model``.

    Args:
        model: Model whose parameters should stop receiving gradients.
    """

def unfreeze(model: torch.nn.Module) -> None:
    """
    Re-enable gradient computation for every parameter of ``model``.

    Args:
        model: Model whose parameters should receive gradients again.
    """

def reparameterize_model(
    model: torch.nn.Module,
    **kwargs
) -> torch.nn.Module:
    """
    Fuse/reparameterize a model's layers for faster inference.

    Args:
        model: Model to transform.
        **kwargs: Options forwarded to the reparameterization routine.

    Returns:
        The reparameterized model.
    """

Distributed Training Utilities

Functions for initializing and managing distributed training across multiple devices and nodes.

def init_distributed_device(args) -> Tuple[torch.device, int]:
    """
    Set up the device and process group for distributed training.

    Args:
        args: Namespace carrying the distributed configuration
            (e.g. backend, rank, world size).

    Returns:
        ``(device, world_size)`` pair for this process.
    """

def distribute_bn(
    model: torch.nn.Module,
    world_size: int,
    reduce: bool = False
) -> None:
    """
    Synchronize batch-norm running statistics across distributed processes.

    Args:
        model: Model containing batch normalization layers.
        world_size: Number of participating processes.
        reduce: Whether to reduce the statistics across processes.
    """

def reduce_tensor(
    tensor: torch.Tensor,
    world_size: int = 1
) -> torch.Tensor:
    """
    All-reduce a tensor over the distributed process group.

    Args:
        tensor: Tensor to combine across processes.
        world_size: Number of processes participating in the reduction.

    Returns:
        The reduced tensor.
    """

def world_info_from_env() -> Tuple[int, int, int]:
    """
    Read distributed topology information from environment variables.

    Returns:
        ``(local_rank, world_rank, world_size)`` triple.
    """

def is_distributed_env() -> bool:
    """
    Report whether the process appears to be running under a distributed launcher.

    Returns:
        True when a distributed environment is detected, otherwise False.
    """

Mixed Precision Training

Utilities for managing mixed precision training with automatic mixed precision (AMP).

class ApexScaler:
    """
    Loss/gradient scaler backed by NVIDIA Apex AMP.

    Args:
        loss_scale: Scaling mode or fixed value ('dynamic' by default).
        init_scale: Starting loss-scale value.
        scale_factor: Multiplier used when adjusting the scale.
        scale_window: Number of steps between scale adjustments.
    """
    
    def __init__(
        self,
        loss_scale: str = 'dynamic',
        init_scale: float = 2.**16,
        scale_factor: float = 2.0,
        scale_window: int = 2000
    ): ...
    
    def scale_loss(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer): ...
    def unscale_grads(self, optimizer: torch.optim.Optimizer): ...
    def update_scale(self, overflow: bool): ...

class NativeScaler:
    """
    Mixed-precision gradient scaler built on native PyTorch AMP.

    Args:
        enabled: Whether scaling is active.
        init_scale: Starting scale value.
        growth_factor: Multiplier applied when the scale grows.
        backoff_factor: Multiplier applied when the scale shrinks.
        growth_interval: Steps between scale-growth attempts.
    """
    
    def __init__(
        self,
        enabled: bool = True,
        init_scale: float = 2.**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000
    ): ...
    
    def scale(self, loss: torch.Tensor) -> torch.Tensor: ...
    def step(self, optimizer: torch.optim.Optimizer) -> None: ...
    def update(self) -> None: ...

CUDA and Performance Utilities

Functions for managing CUDA operations, JIT compilation, and performance optimization.

def set_jit_legacy(enable: bool) -> None:
    """
    Toggle legacy TorchScript JIT executor mode.

    Args:
        enable: True to switch on legacy JIT behavior.
    """

def set_jit_fuser(fuser_name: str) -> None:
    """
    Select the TorchScript JIT fuser backend.

    Args:
        fuser_name: Fuser identifier, one of 'te', 'old', or 'nvfuser'.
    """

def random_seed(seed: int, rank: int = 0) -> None:
    """
    Seed all relevant RNG libraries for reproducible runs.

    Args:
        seed: Base random seed value.
        rank: Process rank, used to vary seeds in distributed training.
    """

Logging and Configuration

Utilities for setting up logging, argument parsing, and experiment configuration.

def setup_default_logging(
    default_level: int = logging.INFO,
    log_path: str = '',
    **kwargs
) -> None:
    """
    Configure the default logging handlers and level.

    Args:
        default_level: Logging level applied by default.
        log_path: Optional path of a log file; empty disables file logging.
        **kwargs: Extra logging configuration options.
    """

def natural_key(string_: str) -> List[Union[int, str]]:
    """
    Build a sort key that orders embedded numbers numerically ("natural sort").

    Args:
        string_: Input string to derive the key from.

    Returns:
        Mixed list of ints and strings suitable as a ``sorted`` key.
    """

def add_bool_arg(
    parser,
    name: str,
    default: bool = False,
    help: str = ''
) -> None:
    """
    Register a paired ``--name`` / ``--no-name`` boolean flag on a parser.

    Args:
        parser: ArgumentParser to add the flag to.
        name: Base name of the argument.
        default: Value used when neither flag is given.
        help: Help text shown in usage output.
    """

Training Summary and Output

Functions for managing training outputs, experiment directories, and result summaries.

def update_summary(
    epoch: int,
    train_metrics: Dict[str, float],
    eval_metrics: Dict[str, float],
    filename: str,
    lr: 'float | None' = None,
    write_header: bool = False,
    log_wandb: bool = False
) -> None:
    """
    Update training summary with metrics.

    Args:
        epoch: Current epoch
        train_metrics: Training metrics dictionary
        eval_metrics: Evaluation metrics dictionary
        filename: Summary file path
        lr: Current learning rate, or None to omit it from the summary
        write_header: Write CSV header
        log_wandb: Log to Weights & Biases
    """

def get_outdir(path: str, *paths: str, inc: bool = False) -> str:
    """
    Resolve (and, when needed, auto-increment) an experiment output directory.

    Args:
        path: Base output path.
        *paths: Extra components joined onto the base path.
        inc: If True, append an incrementing suffix to avoid collisions.

    Returns:
        The resolved output directory path.
    """

Training Monitoring Classes

Metrics Tracking

class AverageMeter:
    """
    Tracks a metric's most recent value and its running average.

    Args:
        name: Metric label used for display.
        fmt: Format spec used when rendering values.
    """
    
    def __init__(self, name: str = '', fmt: str = ':f'): ...
    
    def reset(self) -> None:
        """Clear all accumulated statistics back to their initial state."""
    
    def update(self, val: float, n: int = 1) -> None:
        """
        Fold a new observation into the running statistics.

        Args:
            val: Observed value.
            n: Sample count this value aggregates over.
        """
    
    def __str__(self) -> str:
        """Render the meter's current and average values for display."""

def accuracy(
    output: torch.Tensor,
    target: torch.Tensor,
    topk: Tuple[int, ...] = (1,)
) -> List[torch.Tensor]:
    """
    Compute top-k classification accuracy for each requested k.

    Args:
        output: Prediction scores of shape [batch_size, num_classes].
        target: Ground-truth class indices of shape [batch_size].
        topk: k values to evaluate accuracy at.

    Returns:
        One accuracy tensor per entry in ``topk``.
    """

Model EMA Management

class ModelEma:
    """
    Model Exponential Moving Average for maintaining shadow weights.

    Args:
        model: Model to track with EMA
        decay: EMA decay rate (default: 0.9999)
        device: Device to store EMA parameters on; None keeps the model's device
        resume: Path to resume EMA from checkpoint
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        device: 'torch.device | None' = None,
        resume: str = ''
    ): ...
    
    def update(self, model: torch.nn.Module) -> None:
        """
        Update EMA parameters from model.

        Args:
            model: Source model for updates
        """
    
    def set(self, model: torch.nn.Module) -> None:
        """
        Set EMA parameters from model (copy all parameters).

        Args:
            model: Source model to copy from
        """

class ModelEmaV2:
    """
    Model EMA v2 with improved decay adjustment based on training progress.

    Args:
        model: Model to track
        decay: Base decay rate
        decay_type: Type of decay adjustment ('exponential', 'linear')
        device: Device for EMA parameters; None keeps the model's device
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        decay_type: str = 'exponential',
        device: 'torch.device | None' = None
    ): ...

class ModelEmaV3:
    """
    Model EMA v3 with performance optimizations and memory efficiency.

    Args:
        model: Model to track
        decay: EMA decay rate
        update_after_step: Steps before starting EMA updates
        use_ema_warmup: Use warmup for EMA updates
        inv_gamma: Inverse gamma for warmup
        power: Power for warmup
        min_value: Minimum decay value
        device: Device for parameters; None keeps the model's device
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        update_after_step: int = 100,
        use_ema_warmup: bool = False,
        inv_gamma: float = 1.0,
        power: float = 2/3,
        min_value: float = 0.0,
        device: 'torch.device | None' = None
    ): ...

Checkpoint Management

class CheckpointSaver:
    """
    Saves model checkpoints with configurable retention and recovery policies.

    Args:
        model: Model to save
        optimizer: Optimizer state to save
        args: Training arguments/configuration
        model_ema: EMA model to save, or None to skip EMA saving
        amp_scaler: Mixed precision scaler
        checkpoint_prefix: Prefix for checkpoint filenames
        recovery_prefix: Prefix for recovery checkpoints
        checkpoint_dir: Directory for regular checkpoints
        recovery_dir: Directory for recovery checkpoints
        decreasing: Whether monitored metric is decreasing (lower is better)
        max_history: Maximum number of checkpoints to keep
        unwrap_fn: Function to unwrap model before saving
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        args = None,
        model_ema: 'ModelEma | None' = None,
        amp_scaler = None,
        checkpoint_prefix: str = 'checkpoint',
        recovery_prefix: str = 'recovery',
        checkpoint_dir: str = '',
        recovery_dir: str = '',
        decreasing: bool = False,
        max_history: int = 10,
        unwrap_fn: Callable = unwrap_model
    ): ...
    
    def save_checkpoint(
        self,
        epoch: int,
        metric: 'float | None' = None
    ) -> Tuple[str, bool]:
        """
        Save checkpoint if metric improved.

        Args:
            epoch: Current epoch number
            metric: Metric value for comparison, or None to skip comparison

        Returns:
            Tuple of (checkpoint_path, is_best)
        """
    
    def save_recovery(
        self,
        epoch: int,
        batch_idx: int = 0
    ) -> str:
        """
        Save recovery checkpoint for resuming interrupted training.

        Args:
            epoch: Current epoch
            batch_idx: Current batch index

        Returns:
            Path to saved recovery checkpoint
        """

Usage Examples

Basic Training Setup with Utilities

import logging

import torch  # required for torch.optim.AdamW below; was missing from this example
import timm
from timm.utils import (
    setup_default_logging, random_seed, ModelEma, 
    CheckpointSaver, AverageMeter, accuracy
)

# Configure logging (console plus optional log file) before anything logs.
setup_default_logging(log_path='training.log')
logger = logging.getLogger(__name__)

# Seed RNGs for reproducibility; rank lets distributed workers vary seeds.
random_seed(42, rank=0)

# Create model and training components
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Track an exponential moving average of the weights for evaluation.
model_ema = ModelEma(model, decay=0.9999)

# Save checkpoints, retaining at most `max_history` and tracking the best.
saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    model_ema=model_ema,
    checkpoint_dir='./checkpoints',
    max_history=5,
    decreasing=False  # Higher accuracy is better
)

# Running-average meters for loss and top-1/top-5 accuracy.
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')

Distributed Training Setup

from timm.utils import (
    init_distributed_device, distribute_bn, reduce_tensor,
    is_distributed_env
)

# Initialize distributed training
# NOTE(review): `args` is assumed to be defined by the surrounding script
# (an argparse Namespace with distributed settings) — confirm in context.
device, world_size = init_distributed_device(args)
model = model.to(device)

if is_distributed_env():
    # Synchronize batch norm statistics
    distribute_bn(model, world_size, reduce=True)
    
    # Wrap model for distributed training
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[device], find_unused_parameters=False
    )

# In training loop - reduce metrics across processes
def train_epoch(model, loader, optimizer, device, world_size):
    # Runs one training epoch and returns the average loss across batches.
    # NOTE(review): `criterion` and `AverageMeter` are assumed to be in
    # scope from the surrounding script — confirm before copying verbatim.
    losses = AverageMeter('Loss')
    
    for batch_idx, (input, target) in enumerate(loader):
        input, target = input.to(device), target.to(device)
        
        output = model(input)
        loss = criterion(output, target)
        
        # Backward and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Reduce loss across processes
        if world_size > 1:
            loss = reduce_tensor(loss, world_size)
        
        losses.update(loss.item(), input.size(0))
    
    return losses.avg

Mixed Precision Training

import torch  # required for torch.cuda.amp.autocast below; was missing from this example
from timm.utils import NativeScaler

# Gradient scaler for float16 mixed-precision training.
scaler = NativeScaler()
# NOTE(review): `model` and `device` are assumed to be defined by the
# surrounding script — confirm in context.
model = model.to(device)

def train_step(model, input, target, optimizer, scaler):
    """Run one optimization step with autocast and gradient scaling."""
    optimizer.zero_grad()
    
    # Forward pass with autocast so eligible ops run in reduced precision.
    with torch.cuda.amp.autocast():
        output = model(input)
        loss = criterion(output, target)
    
    # Scale the loss before backward to avoid fp16 gradient underflow, then
    # let the scaler unscale, step the optimizer, and adjust its scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
    return loss.item()

Complete Training Loop with Utilities

def train_model():
    # End-to-end training loop demonstrating the utilities together.
    # NOTE(review): `device`, `criterion`, `num_epochs`, `train_loader`,
    # `val_loader`, `validate`, and `logger` are assumed to be defined by
    # the surrounding script — confirm before copying verbatim.
    setup_default_logging()
    random_seed(42)
    
    # Model setup
    model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    
    # Training utilities: EMA shadow weights, AMP scaler, checkpointing.
    model_ema = ModelEmaV2(model, decay=0.9999)
    scaler = NativeScaler()
    saver = CheckpointSaver(
        model, optimizer, model_ema=model_ema, amp_scaler=scaler,
        checkpoint_dir='./checkpoints'
    )
    
    # Metrics
    train_losses = AverageMeter('Train Loss')
    train_acc1 = AverageMeter('Train Acc@1')
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_losses.reset()
        train_acc1.reset()
        
        for batch_idx, (input, target) in enumerate(train_loader):
            input, target = input.to(device), target.to(device)
            
            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                output = model(input)
                loss = criterion(output, target)
            
            # Backward pass with loss scaling to avoid fp16 underflow
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            # Update EMA shadow weights after each optimizer step
            model_ema.update(model)
            
            # Metrics
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            train_losses.update(loss.item(), input.size(0))
            train_acc1.update(acc1.item(), input.size(0))
        
        # Validation runs against the EMA weights; checkpoint on the metric
        val_acc = validate(model_ema.module, val_loader)
        saver.save_checkpoint(epoch, val_acc)
        
        logger.info(f'Epoch {epoch}: Train Loss {train_losses.avg:.4f}, '
                   f'Train Acc {train_acc1.avg:.2f}%, Val Acc {val_acc:.2f}%')

Types

from typing import Optional, Union, List, Dict, Callable, Any, Tuple
import torch
import logging

# Device and distributed types
DeviceType = torch.device
WorldInfo = Tuple[int, int, int]  # (local_rank, world_rank, world_size)

# Metrics types
MetricValue = Union[float, torch.Tensor]
MetricDict = Dict[str, MetricValue]

# Checkpoint types
CheckpointDict = Dict[str, Any]
UnwrapFunction = Callable[[torch.nn.Module], torch.nn.Module]

# Scaler types
# NOTE(review): torch.cuda.amp.GradScaler is deprecated in newer torch
# releases in favor of torch.amp.GradScaler — confirm minimum torch version.
LossScaler = Union[torch.cuda.amp.GradScaler, Any]

# Logging types
LogLevel = int
Logger = logging.Logger

# Utility function types
SeedFunction = Callable[[int, int], None]  # (seed, rank) -> None
ReduceFunction = Callable[[torch.Tensor, int], torch.Tensor]  # (tensor, world_size) -> tensor

Install with Tessl CLI

npx tessl i tessl/pypi-timm

docs

data.md

features.md

index.md

layers.md

models.md

training.md

utils.md

tile.json