CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-timm

PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks

Overview
Eval results
Files

docs/training.md

Training Infrastructure

Comprehensive training utilities including optimizers, learning rate schedulers, loss functions, and training helpers for building complete training pipelines.

Capabilities

Optimizer Creation

Factory functions for creating optimizers with advanced configurations and parameter grouping strategies.

def create_optimizer_v2(
    model_or_params: Union[torch.nn.Module, ParamsT],
    opt: str = 'sgd',
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    momentum: float = 0.9,
    foreach: Optional[bool] = None,
    filter_bias_and_bn: bool = True,
    layer_decay: Optional[float] = None,
    layer_decay_min_scale: float = 0.0,
    layer_decay_no_opt_scale: Optional[float] = None,
    param_group_fn: Optional[Callable[[torch.nn.Module], ParamsT]] = None,
    **kwargs: Any
) -> torch.optim.Optimizer:
    """
    Create optimizer with v2 interface.

    Args:
        model_or_params: Model instance or parameter groups
        opt: Optimizer name ('sgd', 'adam', 'adamw', 'rmsprop', etc.)
        lr: Learning rate (None keeps the optimizer's own default)
        weight_decay: Weight decay coefficient
        momentum: Momentum coefficient (for SGD-family optimizers)
        foreach: Force-enable/disable multi-tensor (foreach) impl; None = auto
        filter_bias_and_bn: Exclude bias and norm params from weight decay
        layer_decay: Layer-wise learning-rate decay factor (None disables)
        layer_decay_min_scale: Minimum LR scale applied by layer decay
        layer_decay_no_opt_scale: Layer-decay scale threshold below which params
            are excluded from optimization (None disables)
        param_group_fn: Custom callable producing parameter groups from the model
        **kwargs: Extra optimizer-specific arguments (e.g. ``eps``, ``betas``)
            passed through to the optimizer constructor

    Returns:
        Configured optimizer instance
    """

def create_optimizer(
    args,
    model: torch.nn.Module,
    filter_bias_and_bn: bool = True
):
    """
    Create optimizer from arguments (legacy interface).

    Thin wrapper over the v2 factory for argparse-style configuration;
    prefer ``create_optimizer_v2`` in new code.

    Args:
        args: Arguments namespace with optimizer configuration
        model: Model to optimize
        filter_bias_and_bn: Exclude bias and batch norm params from weight decay

    Returns:
        Configured optimizer
    """

def list_optimizers() -> List[str]:
    """
    List available optimizer names.

    Returns:
        List of optimizer names accepted by the ``opt`` argument of the
        optimizer factory functions
    """

def get_optimizer_class(optimizer_name: str):
    """
    Get optimizer class by name.

    Args:
        optimizer_name: Name of optimizer (as returned by ``list_optimizers``)

    Returns:
        Optimizer class (not an instance)
    """

Parameter Grouping

Functions for creating parameter groups with different learning rates, weight decay, and layer-specific configurations.

def param_groups_layer_decay(
    model: torch.nn.Module,
    weight_decay: float = 0.05,
    no_weight_decay_list: Optional[List[str]] = None,
    layer_decay: float = 0.75,
    end_lr_scale: float = 1.0
) -> List[dict]:
    """
    Create parameter groups with layer-wise learning rate decay.

    Args:
        model: Model to create parameter groups for
        weight_decay: Base weight decay rate
        no_weight_decay_list: Parameter names to exclude from weight decay
            (None treated as empty)
        layer_decay: Per-layer decay factor; earlier layers get smaller LR scale
        end_lr_scale: Learning rate scale for final layer

    Returns:
        List of parameter group dictionaries suitable for an optimizer
    """

def param_groups_weight_decay(
    model: torch.nn.Module,
    weight_decay: float = 1e-5,
    no_weight_decay_list: Optional[List[str]] = None
) -> List[dict]:
    """
    Create parameter groups with selective weight decay.

    Args:
        model: Model to create parameter groups for
        weight_decay: Weight decay rate
        no_weight_decay_list: Parameter names to exclude from weight decay
            (None treated as empty)

    Returns:
        List of parameter group dictionaries suitable for an optimizer
    """

Optimizer Classes

Custom Optimizers

class AdaBelief(torch.optim.Optimizer):
    """
    AdaBelief optimizer.

    Args:
        params: Iterable of parameters or parameter groups
        lr: Learning rate
        betas: Beta coefficients for first/second moment estimates
        eps: Epsilon for numerical stability
        weight_decay: Weight decay coefficient
        amsgrad: Use AMSGrad variant
        weight_decouple: Decouple weight decay (AdamW-style)
        fixed_decay: Use fixed decay
        rectify: Use rectification (RAdam-style)
    """
    
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-16,
        weight_decay: float = 0,
        amsgrad: bool = False,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = True
    ): ...

class Lamb(torch.optim.Optimizer):
    """
    LAMB (Layer-wise Adaptive Moments) optimizer.

    Args:
        params: Iterable of parameters or parameter groups
        lr: Learning rate
        betas: Beta coefficients for first/second moment estimates
        eps: Epsilon for numerical stability
        weight_decay: Weight decay coefficient
        grad_averaging: Use gradient averaging
        max_grad_norm: Maximum gradient norm for clipping
        trust_clip: Trust region clipping
        always_adapt: Always apply the LAMB trust-ratio adaptation
    """
    
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        grad_averaging: bool = True,
        max_grad_norm: float = 1.0,
        trust_clip: bool = False,
        always_adapt: bool = False
    ): ...

class Lion(torch.optim.Optimizer):
    """
    Lion (EvoLved Sign Momentum) optimizer.

    Args:
        params: Iterable of parameters or parameter groups
        lr: Learning rate (typically smaller than Adam's)
        betas: Beta coefficients for momentum
        weight_decay: Weight decay coefficient
        use_triton: Use Triton kernel implementation
    """
    
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False
    ): ...

class Lookahead(torch.optim.Optimizer):
    """
    Lookahead optimizer wrapper.

    Wraps any base optimizer; slow weights are updated toward fast weights
    every ``k`` steps.

    Args:
        base_optimizer: Base optimizer to wrap
        alpha: Lookahead (slow-weight) step size
        k: Lookahead update frequency, in base-optimizer steps
        pullback_momentum: Pullback momentum mode
    """
    
    def __init__(
        self,
        base_optimizer: torch.optim.Optimizer,
        alpha: float = 0.5,
        k: int = 6,
        pullback_momentum: str = "none"
    ): ...

Learning Rate Schedulers

Scheduler Creation

def create_scheduler_v2(
    optimizer: torch.optim.Optimizer,
    sched: str = 'cosine',
    num_epochs: int = 300,
    decay_epochs: int = 90,
    decay_milestones: Tuple[int, ...] = (90, 180, 270),
    cooldown_epochs: int = 0,
    patience_epochs: int = 10,
    decay_rate: float = 0.1,
    min_lr: float = 0,
    warmup_lr: float = 1e-5,
    warmup_epochs: int = 0,
    warmup_prefix: bool = False,
    noise: Optional[Union[float, List[float]]] = None,
    noise_pct: float = 0.67,
    noise_std: float = 1.0,
    noise_seed: int = 42,
    cycle_mul: float = 1.0,
    cycle_decay: float = 0.1,
    cycle_limit: int = 1,
    k_decay: float = 1.0,
    plateau_mode: str = 'max',
    step_on_epochs: bool = True,
    updates_per_epoch: int = 0
):
    """
    Create learning rate scheduler with v2 interface.

    Args:
        optimizer: Optimizer instance
        sched: Scheduler type ('step', 'cosine', 'tanh', 'poly', 'plateau', etc.)
        num_epochs: Total number of training epochs
        decay_epochs: Epochs between learning rate decay (step scheduler)
        decay_milestones: Epoch milestones for multi-step decay
        cooldown_epochs: Number of cooldown epochs
        patience_epochs: Patience for plateau scheduler
        decay_rate: Learning rate decay factor
        min_lr: Minimum learning rate
        warmup_lr: Warmup initial learning rate
        warmup_epochs: Number of warmup epochs
        warmup_prefix: Run warmup before the first cycle rather than within it
        noise: Learning rate noise on/off epoch fraction(s); None disables noise
        noise_pct: Noise percentage
        noise_std: Noise standard deviation
        noise_seed: Random seed for noise
        cycle_mul: Cycle length multiplier
        cycle_decay: Cycle decay factor
        cycle_limit: Maximum number of cycles
        k_decay: K decay factor (poly/cosine variants)
        plateau_mode: Plateau mode ('min' or 'max')
        step_on_epochs: Step on epochs vs iterations
        updates_per_epoch: Updates per epoch for iteration-based stepping

    Returns:
        Configured scheduler instance
    """

def scheduler_kwargs(args) -> dict:
    """
    Extract scheduler keyword arguments from args.

    Args:
        args: Arguments namespace (argparse-style training configuration)

    Returns:
        Dictionary of keyword arguments for ``create_scheduler_v2``
    """

Scheduler Classes

class CosineLRScheduler:
    """
    Cosine annealing learning rate scheduler with warm restarts.

    Args:
        optimizer: Optimizer instance
        t_initial: Initial number of epochs/iterations per cycle
        lr_min: Minimum learning rate
        cycle_mul: Cycle length multiplier
        cycle_decay: Cycle amplitude decay
        cycle_limit: Maximum number of cycles
        warmup_t: Warmup iterations
        warmup_lr_init: Initial warmup learning rate
        warmup_prefix: Run warmup before the first cycle
        t_in_epochs: Interpret t_initial as epochs (else iterations)
        noise_range_t: Noise range for time (None disables noise)
        noise_pct: Noise percentage
        noise_std: Noise standard deviation
        noise_seed: Random seed
        initialize: Initialize learning rates on construction
    """
    
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        t_initial: int,
        lr_min: float = 0.0,
        cycle_mul: float = 1.0,
        cycle_decay: float = 1.0,
        cycle_limit: int = 1,
        warmup_t: int = 0,
        warmup_lr_init: float = 0,
        warmup_prefix: bool = False,
        t_in_epochs: bool = True,
        noise_range_t: Optional[tuple] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.0,
        noise_seed: Optional[int] = None,
        initialize: bool = True
    ): ...

class StepLRScheduler:
    """
    Step learning rate scheduler.

    Args:
        optimizer: Optimizer instance
        decay_t: Step interval (int) or explicit milestones (list) for decay
        decay_rate: Decay factor
        warmup_t: Warmup iterations
        warmup_lr_init: Initial warmup learning rate
        t_in_epochs: Interpret intervals as epochs (else iterations)
        noise_range_t: Noise range for time (None disables noise)
        noise_pct: Noise percentage
        noise_std: Noise standard deviation
        noise_seed: Random seed
        initialize: Initialize learning rates on construction
    """
    
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        decay_t: Union[int, List[int]],
        decay_rate: float = 0.1,
        warmup_t: int = 0,
        warmup_lr_init: float = 0,
        t_in_epochs: bool = True,
        noise_range_t: Optional[tuple] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.0,
        noise_seed: Optional[int] = None,
        initialize: bool = True
    ): ...

class PlateauLRScheduler:
    """
    Plateau-based learning rate scheduler.

    Decays the learning rate when the monitored metric stops improving.

    Args:
        optimizer: Optimizer instance
        decay_rate: Decay factor when plateau detected
        patience_t: Patience (intervals without improvement) before decay
        verbose: Print decay messages
        threshold: Threshold for measuring improvement
        cooldown_t: Cooldown period after decay
        mode: Mode for plateau detection ('min' or 'max')
        lr_min: Minimum learning rate
        warmup_t: Warmup iterations
        warmup_lr_init: Initial warmup learning rate
        t_in_epochs: Interpret intervals as epochs (else iterations)
        noise_range_t: Noise range for time (None disables noise)
        noise_pct: Noise percentage
        noise_std: Noise standard deviation
        noise_seed: Random seed
        initialize: Initialize learning rates on construction
    """
    
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        decay_rate: float = 0.1,
        patience_t: int = 10,
        verbose: bool = True,
        threshold: float = 1e-4,
        cooldown_t: int = 0,
        mode: str = 'max',
        lr_min: float = 0,
        warmup_t: int = 0,
        warmup_lr_init: float = 0,
        t_in_epochs: bool = True,
        noise_range_t: Optional[tuple] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.0,
        noise_seed: Optional[int] = None,
        initialize: bool = True
    ): ...

Loss Functions

Loss Classes

class LabelSmoothingCrossEntropy(torch.nn.Module):
    """
    Cross entropy loss with label smoothing.

    Args:
        smoothing: Label smoothing factor (0.0 to 1.0)
        weight: Class weights for unbalanced datasets (None = uniform)
        reduction: Loss reduction ('mean', 'sum', 'none')
    """
    
    def __init__(
        self,
        smoothing: float = 0.1,
        weight: Optional[torch.Tensor] = None,
        reduction: str = 'mean'
    ): ...

class SoftTargetCrossEntropy(torch.nn.Module):
    """
    Cross entropy loss with soft targets (for knowledge distillation).

    Args:
        weight: Class weights (None = uniform)
        size_average: Deprecated, use reduction
        ignore_index: Index to ignore in loss computation
        reduce: Deprecated, use reduction
        reduction: Loss reduction ('mean', 'sum', 'none')
    """
    
    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = 'mean'
    ): ...

class JsdCrossEntropy(torch.nn.Module):
    """
    Jensen-Shannon divergence cross entropy loss.

    Args:
        num_splits: Number of augmentation splits
        alpha: Mixing parameter for splits
        weight: Class weights (None = uniform)
        size_average: Deprecated, use reduction
        ignore_index: Index to ignore in loss computation
        reduce: Deprecated, use reduction
        reduction: Loss reduction ('mean', 'sum', 'none')
        smoothing: Label smoothing factor
    """
    
    def __init__(
        self,
        num_splits: int = 2,
        alpha: float = 12.0,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = 'mean',
        smoothing: float = 0.1
    ): ...

class BinaryCrossEntropy(torch.nn.Module):
    """
    Binary cross entropy loss with optional smoothing.

    Args:
        smoothing: Label smoothing factor
        target_threshold: Threshold for converting to hard targets
            (None = keep soft targets)
        weight: Class weights (None = uniform)
        reduction: Loss reduction ('mean', 'sum', 'none')
        pos_weight: Positive class weight (None = unweighted)
    """
    
    def __init__(
        self,
        smoothing: float = 0.0,
        target_threshold: Optional[float] = None,
        weight: Optional[torch.Tensor] = None,
        reduction: str = 'mean',
        pos_weight: Optional[torch.Tensor] = None
    ): ...

class AsymmetricLossMultiLabel(torch.nn.Module):
    """
    Asymmetric loss for multi-label classification.

    Args:
        gamma_neg: Focusing parameter for negative examples
        gamma_pos: Focusing parameter for positive examples
        clip: Clipping value for probability (asymmetric probability shifting)
        eps: Epsilon for numerical stability
        disable_torch_grad_focal_loss: Disable gradient computation for the
            focal-loss weighting term
    """
    
    def __init__(
        self,
        gamma_neg: float = 4,
        gamma_pos: float = 1,
        clip: float = 0.05,
        eps: float = 1e-8,
        disable_torch_grad_focal_loss: bool = False
    ): ...

Training Utilities

Model EMA (Exponential Moving Average)

class ModelEma:
    """
    Model Exponential Moving Average.

    Keeps a shadow copy of model parameters updated as an EMA of the
    tracked model's parameters.

    Args:
        model: Model to track
        decay: EMA decay rate
        device: Device for EMA parameters (None = same as model)
        resume: Resume from checkpoint path ('' = no resume)
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        device: Optional[torch.device] = None,
        resume: str = ''
    ): ...
    
    def update(self, model: torch.nn.Module) -> None:
        """Update EMA parameters from the given model."""
    
    def set(self, model: torch.nn.Module) -> None:
        """Set EMA parameters directly from the given model (no averaging)."""

class ModelEmaV2:
    """
    Model EMA v2 with improved decay adjustment.

    Args:
        model: Model to track
        decay: Base decay rate
        decay_type: Decay adjustment type (default 'exponential')
        device: Device for EMA parameters (None = same as model)
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        decay_type: str = 'exponential',
        device: Optional[torch.device] = None
    ): ...

Gradient Utilities

def adaptive_clip_grad(
    parameters,
    clip_factor: float = 0.01,
    eps: float = 1e-3,
    norm_type: float = 2.0
) -> torch.Tensor:
    """
    Adaptive gradient clipping (AGC): clip gradients relative to parameter norms.

    Args:
        parameters: Model parameters (iterable of tensors)
        clip_factor: Adaptive clipping factor
        eps: Epsilon for numerical stability
        norm_type: Norm type for gradient computation

    Returns:
        Gradient norm
    """

def dispatch_clip_grad(
    parameters,
    value: float,
    mode: str = 'norm',
    norm_type: float = 2.0
) -> torch.Tensor:
    """
    Dispatch gradient clipping to the selected method.

    Args:
        parameters: Model parameters (iterable of tensors)
        value: Clipping value (norm bound, value bound, or AGC clip factor)
        mode: Clipping mode ('norm', 'value', 'agc')
        norm_type: Norm type for gradient computation

    Returns:
        Gradient norm
    """

Checkpointing

class CheckpointSaver:
    """
    Model checkpoint saver with configurable retention policy.

    Args:
        model: Model to save
        optimizer: Optimizer to save
        args: Training arguments (None to omit from checkpoints)
        model_ema: EMA model to save (None to omit)
        amp_scaler: AMP scaler to save (None to omit)
        checkpoint_prefix: Checkpoint filename prefix
        recovery_prefix: Recovery checkpoint prefix
        checkpoint_dir: Directory for checkpoints
        recovery_dir: Directory for recovery checkpoints
        decreasing: Monitored metric improves when decreasing (e.g. loss)
        max_history: Maximum number of checkpoints to retain
        unwrap_fn: Function to unwrap model (e.g. for DDP wrappers; None = default)
    """
    
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        args=None,
        model_ema: Optional[ModelEma] = None,
        amp_scaler=None,
        checkpoint_prefix: str = 'checkpoint',
        recovery_prefix: str = 'recovery',
        checkpoint_dir: str = '',
        recovery_dir: str = '',
        decreasing: bool = False,
        max_history: int = 10,
        unwrap_fn: Optional[Callable] = None
    ): ...
    
    def save_checkpoint(
        self,
        epoch: int,
        metric: Optional[float] = None
    ) -> str:
        """Save a checkpoint for the given epoch, tracking the best metric."""
    
    def save_recovery(self, epoch: int, batch_idx: int = 0) -> str:
        """Save a recovery (mid-epoch) checkpoint."""

Metrics and Monitoring

class AverageMeter:
    """
    Computes and stores the average and current value.

    Args:
        name: Meter name (used for display)
        fmt: Format string for display
    """
    
    def __init__(self, name: str = '', fmt: str = ':f'): ...
    
    def reset(self) -> None:
        """Reset all statistics."""
    
    def update(self, val: float, n: int = 1) -> None:
        """Update with a new value; ``n`` is the sample count it represents."""

def accuracy(
    output: torch.Tensor,
    target: torch.Tensor,
    topk: tuple = (1,)
) -> List[torch.Tensor]:
    """
    Compute accuracy for specified top-k values.

    Args:
        output: Model predictions (logits or probabilities, batch x classes)
        target: Ground truth labels (batch,)
        topk: Top-k values to compute

    Returns:
        List of accuracy values, one per entry in ``topk``
    """

Usage Examples

Complete Training Setup

import timm
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
from timm.loss import LabelSmoothingCrossEntropy
from timm.utils import ModelEma, CheckpointSaver, AverageMeter

# Create a pretrained model
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)

# Create AdamW optimizer (bias/norm params excluded from decay by default)
optimizer = create_optimizer_v2(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05
)

# Create cosine learning rate scheduler with 5 warmup epochs
scheduler = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    warmup_lr=1e-5,
    min_lr=1e-6
)

# Create loss function with label smoothing
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

# Create EMA copy of the model weights
model_ema = ModelEma(model, decay=0.9999)

# Create checkpoint saver retaining the 5 best checkpoints
saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    model_ema=model_ema,
    checkpoint_dir='./checkpoints',
    max_history=5
)

# Running metrics for the training loop
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')

Advanced Optimizer Configuration

from timm.optim import param_groups_layer_decay, Lamb, Lookahead

# Create parameter groups with layer-wise LR decay (earlier layers get lower LR)
param_groups = param_groups_layer_decay(
    model,
    weight_decay=0.05,
    layer_decay=0.8
)

# Create LAMB optimizer over the layer-decayed groups
base_optimizer = Lamb(param_groups, lr=1e-3)

# Wrap with Lookahead: slow weights updated every k=6 steps with step size alpha
optimizer = Lookahead(base_optimizer, alpha=0.5, k=6)

Types

from typing import Optional, Union, List, Dict, Callable, Any, Tuple
import torch

# Optimizer and scheduler types
OptimizerType = torch.optim.Optimizer
# NOTE(review): _LRScheduler is a private name, deprecated in favor of
# torch.optim.lr_scheduler.LRScheduler in newer PyTorch — confirm target version
SchedulerType = torch.optim.lr_scheduler._LRScheduler

# Parameter group types (as consumed by torch optimizers)
ParamGroup = Dict[str, Any]
ParamGroups = List[ParamGroup]

# Loss function type (any nn.Module computing a loss)
LossFunction = torch.nn.Module

# Metric types
MetricValue = Union[float, torch.Tensor]
MetricDict = Dict[str, MetricValue]

Install with Tessl CLI

npx tessl i tessl/pypi-timm

docs

data.md

features.md

index.md

layers.md

models.md

training.md

utils.md

tile.json