PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks
Comprehensive training utilities including optimizers, learning rate schedulers, loss functions, and training helpers for building complete training pipelines.
Factory functions for creating optimizers with advanced configurations and parameter grouping strategies.
def create_optimizer_v2(
    model_or_params: Union[torch.nn.Module, "ParamsT"],
    opt: str = 'sgd',
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    momentum: float = 0.9,
    foreach: Optional[bool] = None,
    filter_bias_and_bn: bool = True,
    layer_decay: Optional[float] = None,
    layer_decay_min_scale: float = 0.0,
    layer_decay_no_opt_scale: Optional[float] = None,
    param_group_fn: Optional[Callable[[torch.nn.Module], "ParamsT"]] = None,
    **kwargs: Any,
) -> torch.optim.Optimizer:
    """Create an optimizer (v2 interface).

    Args:
        model_or_params: Model instance or parameter groups to optimize.
        opt: Optimizer name ('sgd', 'adam', 'adamw', 'rmsprop', etc.).
        lr: Learning rate; None uses the optimizer's own default.
        weight_decay: Weight decay coefficient.
        momentum: Momentum coefficient (for SGD-family optimizers).
        foreach: Force enable/disable multi-tensor (foreach) kernels; None
            leaves the choice to the optimizer.
        filter_bias_and_bn: Exclude bias and norm parameters from weight decay.
        layer_decay: Layer-wise learning-rate decay factor; None disables.
        layer_decay_min_scale: Minimum LR scale when layer decay is applied.
        layer_decay_no_opt_scale: LR scale below which parameters are not
            optimized; None disables.
        param_group_fn: Optional callable producing parameter groups from the
            model, overriding the default grouping.
        **kwargs: Additional optimizer-specific arguments (e.g. eps, betas).

    Returns:
        Configured optimizer instance.
    """
def create_optimizer(args, model: torch.nn.Module, filter_bias_and_bn: bool = True):
    """Legacy optimizer factory.

    Builds an optimizer for ``model`` from an argparse-style namespace.

    Args:
        args: Namespace carrying the optimizer configuration.
        model: Model whose parameters will be optimized.
        filter_bias_and_bn: If True, filter bias and batch-norm parameters
            out of weight decay.

    Returns:
        The configured optimizer.
    """
def list_optimizers() -> List[str]:
    """Return the names of all supported optimizers.

    Returns:
        List of supported optimizer names.
    """
def get_optimizer_class(optimizer_name: str):
    """Look up an optimizer class by its registered name.

    Args:
        optimizer_name: Name of the optimizer.

    Returns:
        Optimizer class.
    """


# Functions for creating parameter groups with different learning rates,
# weight decay, and layer-specific configurations.
def param_groups_layer_decay(
    model: torch.nn.Module,
    weight_decay: float = 0.05,
    no_weight_decay_list: Optional[List[str]] = None,
    layer_decay: float = 0.75,
    end_lr_scale: float = 1.0,
) -> List[dict]:
    """Create parameter groups with layer-wise learning-rate decay.

    Args:
        model: Model to create parameter groups for.
        weight_decay: Base weight decay rate.
        no_weight_decay_list: Parameter names to exclude from weight decay
            (None means no exclusions).
        layer_decay: Layer decay factor.
        end_lr_scale: Learning-rate scale for the final layer.

    Returns:
        List of parameter group dictionaries.
    """
def param_groups_weight_decay(
    model: torch.nn.Module,
    weight_decay: float = 1e-5,
    no_weight_decay_list: Optional[List[str]] = None,
) -> List[dict]:
    """Create parameter groups with selective weight decay.

    Args:
        model: Model to create parameter groups for.
        weight_decay: Weight decay rate.
        no_weight_decay_list: Parameter names to exclude from weight decay
            (None means no exclusions).

    Returns:
        List of parameter group dictionaries.
    """


class AdaBelief(torch.optim.Optimizer):
    """AdaBelief optimizer.

    Args:
        params: Iterable of parameters.
        lr: Learning rate.
        betas: Beta coefficients.
        eps: Epsilon for numerical stability.
        weight_decay: Weight decay coefficient.
        amsgrad: Use AMSGrad variant.
        weight_decouple: Decouple weight decay.
        fixed_decay: Use fixed decay.
        rectify: Use rectification.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: tuple = (0.9, 0.999),
        eps: float = 1e-16,
        weight_decay: float = 0,
        amsgrad: bool = False,
        weight_decouple: bool = True,
        fixed_decay: bool = False,
        rectify: bool = True,
    ): ...
class Lamb(torch.optim.Optimizer):
    """LAMB (Layer-wise Adaptive Moments) optimizer.

    Args:
        params: Iterable of parameters.
        lr: Learning rate.
        betas: Beta coefficients.
        eps: Epsilon for numerical stability.
        weight_decay: Weight decay coefficient.
        grad_averaging: Use gradient averaging.
        max_grad_norm: Maximum gradient norm.
        trust_clip: Trust region clipping.
        always_adapt: Always adapt learning rate.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        # Annotated as a 2-tuple of floats (was bare `tuple`).
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        grad_averaging: bool = True,
        max_grad_norm: float = 1.0,
        trust_clip: bool = False,
        always_adapt: bool = False,
    ): ...
class Lion(torch.optim.Optimizer):
    """Lion (EvoLved Sign Momentum) optimizer.

    Args:
        params: Iterable of parameters.
        lr: Learning rate.
        betas: Beta coefficients for momentum.
        weight_decay: Weight decay coefficient.
        use_triton: Use Triton kernel implementation.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: tuple = (0.9, 0.99),
        weight_decay: float = 0.0,
        use_triton: bool = False,
    ): ...
class Lookahead(torch.optim.Optimizer):
    """Lookahead optimizer wrapper.

    Args:
        base_optimizer: Base optimizer to wrap.
        alpha: Lookahead step size.
        k: Lookahead frequency (fast steps between slow updates).
        pullback_momentum: Pullback momentum mode.
    """

    def __init__(
        self,
        base_optimizer: torch.optim.Optimizer,
        alpha: float = 0.5,
        k: int = 6,
        pullback_momentum: str = "none",
    ): ...


def create_scheduler_v2(
    optimizer: torch.optim.Optimizer,
    sched: str = 'cosine',
    num_epochs: int = 300,
    decay_epochs: int = 90,
    # Default is a tuple, so annotate accordingly (was `List[int]`).
    decay_milestones: Tuple[int, ...] = (90, 180, 270),
    cooldown_epochs: int = 0,
    patience_epochs: int = 10,
    decay_rate: float = 0.1,
    min_lr: float = 0,
    warmup_lr: float = 1e-5,
    warmup_epochs: int = 0,
    warmup_prefix: bool = False,
    noise: Optional[Union[float, List[float]]] = None,
    noise_pct: float = 0.67,
    noise_std: float = 1.0,
    noise_seed: int = 42,
    cycle_mul: float = 1.0,
    cycle_decay: float = 0.1,
    cycle_limit: int = 1,
    k_decay: float = 1.0,
    plateau_mode: str = 'max',
    step_on_epochs: bool = True,
    updates_per_epoch: int = 0,
):
    """Create a learning-rate scheduler (v2 interface).

    Args:
        optimizer: Optimizer instance.
        sched: Scheduler type ('step', 'cosine', 'tanh', 'poly', 'plateau', etc.).
        num_epochs: Total number of training epochs.
        decay_epochs: Epochs between learning-rate decay.
        decay_milestones: Epoch milestones for multi-step decay.
        cooldown_epochs: Number of cooldown epochs.
        patience_epochs: Patience for plateau scheduler.
        decay_rate: Learning-rate decay factor.
        min_lr: Minimum learning rate.
        warmup_lr: Warmup initial learning rate.
        warmup_epochs: Number of warmup epochs.
        warmup_prefix: Run warmup before the schedule's first cycle.
        noise: Learning-rate noise range (None disables noise).
        noise_pct: Noise percentage.
        noise_std: Noise standard deviation.
        noise_seed: Random seed for noise.
        cycle_mul: Cycle length multiplier.
        cycle_decay: Cycle decay factor.
        cycle_limit: Maximum number of cycles.
        k_decay: K decay factor.
        plateau_mode: Plateau mode ('min' or 'max').
        step_on_epochs: Step on epochs (True) vs iterations (False).
        updates_per_epoch: Updates per epoch for iteration-based stepping.

    Returns:
        Configured scheduler instance.
    """
def scheduler_kwargs(args) -> dict:
    """Extract scheduler keyword arguments from args.

    Args:
        args: Arguments namespace.

    Returns:
        Dictionary of scheduler arguments.
    """


class CosineLRScheduler:
    """Cosine annealing learning-rate scheduler with warm restarts.

    Args:
        optimizer: Optimizer instance.
        t_initial: Initial number of epochs/iterations.
        lr_min: Minimum learning rate.
        cycle_mul: Cycle length multiplier.
        cycle_decay: Cycle amplitude decay.
        cycle_limit: Maximum number of cycles.
        warmup_t: Warmup iterations.
        warmup_lr_init: Initial warmup learning rate.
        warmup_prefix: Warmup before first cycle.
        t_in_epochs: Interpret t_initial as epochs.
        noise_range_t: Noise range for time (None disables noise).
        noise_pct: Noise percentage.
        noise_std: Noise standard deviation.
        noise_seed: Random seed.
        initialize: Initialize learning rates.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        t_initial: int,
        lr_min: float = 0.0,
        cycle_mul: float = 1.0,
        cycle_decay: float = 1.0,
        cycle_limit: int = 1,
        warmup_t: int = 0,
        warmup_lr_init: float = 0,
        warmup_prefix: bool = False,
        t_in_epochs: bool = True,
        noise_range_t: Optional[tuple] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.0,
        noise_seed: Optional[int] = None,
        initialize: bool = True,
    ): ...
class StepLRScheduler:
    """Step learning-rate scheduler.

    Args:
        optimizer: Optimizer instance.
        decay_t: Step interval(s) for decay.
        decay_rate: Decay factor.
        warmup_t: Warmup iterations.
        warmup_lr_init: Initial warmup learning rate.
        t_in_epochs: Interpret intervals as epochs.
        noise_range_t: Noise range for time (None disables noise).
        noise_pct: Noise percentage.
        noise_std: Noise standard deviation.
        noise_seed: Random seed.
        initialize: Initialize learning rates.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        decay_t: Union[int, List[int]],
        decay_rate: float = 0.1,
        warmup_t: int = 0,
        warmup_lr_init: float = 0,
        t_in_epochs: bool = True,
        noise_range_t: Optional[tuple] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.0,
        noise_seed: Optional[int] = None,
        initialize: bool = True,
    ): ...
class PlateauLRScheduler:
    """Plateau-based learning-rate scheduler.

    Args:
        optimizer: Optimizer instance.
        decay_rate: Decay factor when a plateau is detected.
        patience_t: Patience before decay.
        verbose: Print decay messages.
        threshold: Threshold for measuring improvement.
        cooldown_t: Cooldown period after decay.
        mode: Mode for plateau detection ('min' or 'max').
        lr_min: Minimum learning rate.
        warmup_t: Warmup iterations.
        warmup_lr_init: Initial warmup learning rate.
        t_in_epochs: Interpret intervals as epochs.
        noise_range_t: Noise range for time (None disables noise).
        noise_pct: Noise percentage.
        noise_std: Noise standard deviation.
        noise_seed: Random seed.
        initialize: Initialize learning rates.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        decay_rate: float = 0.1,
        patience_t: int = 10,
        verbose: bool = True,
        threshold: float = 1e-4,
        cooldown_t: int = 0,
        mode: str = 'max',
        lr_min: float = 0,
        warmup_t: int = 0,
        warmup_lr_init: float = 0,
        t_in_epochs: bool = True,
        noise_range_t: Optional[tuple] = None,
        noise_pct: float = 0.67,
        noise_std: float = 1.0,
        noise_seed: Optional[int] = None,
        initialize: bool = True,
    ): ...


class LabelSmoothingCrossEntropy(torch.nn.Module):
    """Cross-entropy loss with label smoothing.

    Args:
        smoothing: Label smoothing factor (0.0 to 1.0).
        weight: Class weights for unbalanced datasets (None for uniform).
        reduction: Loss reduction ('mean', 'sum', 'none').
    """

    def __init__(
        self,
        smoothing: float = 0.1,
        weight: Optional[torch.Tensor] = None,
        reduction: str = 'mean',
    ): ...
class SoftTargetCrossEntropy(torch.nn.Module):
    """Cross-entropy loss with soft targets (for knowledge distillation).

    Args:
        weight: Class weights (None for uniform).
        size_average: Deprecated, use reduction.
        ignore_index: Index to ignore in loss computation.
        reduce: Deprecated, use reduction.
        reduction: Loss reduction ('mean', 'sum', 'none').
    """

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = 'mean',
    ): ...
class JsdCrossEntropy(torch.nn.Module):
    """Jensen-Shannon divergence cross-entropy loss.

    Args:
        num_splits: Number of augmentation splits.
        alpha: Mixing parameter for splits.
        weight: Class weights (None for uniform).
        size_average: Deprecated, use reduction.
        ignore_index: Index to ignore.
        reduce: Deprecated, use reduction.
        reduction: Loss reduction.
        smoothing: Label smoothing factor.
    """

    def __init__(
        self,
        num_splits: int = 2,
        alpha: float = 12.0,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = 'mean',
        smoothing: float = 0.1,
    ): ...
class BinaryCrossEntropy(torch.nn.Module):
    """Binary cross-entropy loss with optional smoothing.

    Args:
        smoothing: Label smoothing factor.
        target_threshold: Threshold for hard targets (None disables).
        weight: Class weights (None for uniform).
        reduction: Loss reduction.
        pos_weight: Positive class weight (None for uniform).
    """

    def __init__(
        self,
        smoothing: float = 0.0,
        target_threshold: Optional[float] = None,
        weight: Optional[torch.Tensor] = None,
        reduction: str = 'mean',
        pos_weight: Optional[torch.Tensor] = None,
    ): ...
class AsymmetricLossMultiLabel(torch.nn.Module):
    """Asymmetric loss for multi-label classification.

    Args:
        gamma_neg: Focusing parameter for negative examples.
        gamma_pos: Focusing parameter for positive examples.
        clip: Clipping value for probability.
        eps: Epsilon for numerical stability.
        disable_torch_grad_focal_loss: Disable gradient computation for the
            focal-loss term.
    """

    def __init__(
        self,
        gamma_neg: float = 4,
        gamma_pos: float = 1,
        clip: float = 0.05,
        eps: float = 1e-8,
        disable_torch_grad_focal_loss: bool = False,
    ): ...


class ModelEma:
    """Model Exponential Moving Average.

    Args:
        model: Model to track.
        decay: EMA decay rate.
        device: Device for EMA parameters (None keeps the model's device).
        resume: Resume from checkpoint path ('' disables).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        device: Optional[torch.device] = None,
        resume: str = '',
    ): ...

    def update(self, model: torch.nn.Module) -> None:
        """Update EMA parameters."""

    def set(self, model: torch.nn.Module) -> None:
        """Set EMA parameters from model."""
class ModelEmaV2:
    """Model EMA v2 with improved decay adjustment.

    Args:
        model: Model to track.
        decay: Base decay rate.
        decay_type: Decay adjustment type.
        device: Device for EMA parameters (None keeps the model's device).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        decay_type: str = 'exponential',
        device: Optional[torch.device] = None,
    ): ...


def adaptive_clip_grad(
    parameters,
    clip_factor: float = 0.01,
    eps: float = 1e-3,
    norm_type: float = 2.0,
) -> torch.Tensor:
    """Adaptive gradient clipping.

    Args:
        parameters: Model parameters.
        clip_factor: Adaptive clipping factor.
        eps: Epsilon for numerical stability.
        norm_type: Norm type for gradient computation.

    Returns:
        Gradient norm.
    """
def dispatch_clip_grad(
    parameters,
    value: float,
    mode: str = 'norm',
    norm_type: float = 2.0,
) -> torch.Tensor:
    """Dispatch gradient clipping method.

    Args:
        parameters: Model parameters.
        value: Clipping value.
        mode: Clipping mode ('norm', 'value', 'agc').
        norm_type: Norm type for gradient computation.

    Returns:
        Gradient norm.
    """


class CheckpointSaver:
    """Model checkpoint saver with a configurable retention policy.

    Args:
        model: Model to save.
        optimizer: Optimizer to save.
        args: Training arguments.
        model_ema: EMA model to save (None disables).
        amp_scaler: AMP scaler to save (None disables).
        checkpoint_prefix: Checkpoint filename prefix.
        recovery_prefix: Recovery checkpoint prefix.
        checkpoint_dir: Directory for checkpoints.
        recovery_dir: Directory for recovery checkpoints.
        decreasing: Monitor a decreasing metric.
        max_history: Maximum checkpoint history.
        unwrap_fn: Function to unwrap the model (None disables).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        args=None,
        # Forward reference: ModelEma is defined earlier in this module.
        model_ema: Optional['ModelEma'] = None,
        amp_scaler=None,
        checkpoint_prefix: str = 'checkpoint',
        recovery_prefix: str = 'recovery',
        checkpoint_dir: str = '',
        recovery_dir: str = '',
        decreasing: bool = False,
        max_history: int = 10,
        unwrap_fn: Optional[Callable] = None,
    ): ...

    def save_checkpoint(
        self,
        epoch: int,
        metric: Optional[float] = None,
    ) -> str:
        """Save checkpoint."""

    def save_recovery(self, epoch: int, batch_idx: int = 0) -> str:
        """Save recovery checkpoint."""


class AverageMeter:
    """Computes and stores the average and current value.

    Args:
        name: Meter name.
        fmt: Format string for display.
    """

    def __init__(self, name: str = '', fmt: str = ':f'): ...

    def reset(self) -> None:
        """Reset all statistics."""

    def update(self, val: float, n: int = 1) -> None:
        """Update with new value."""
def accuracy(
    output: torch.Tensor,
    target: torch.Tensor,
    topk: tuple = (1,),
) -> List[torch.Tensor]:
    """Compute accuracy for specified top-k values.

    Args:
        output: Model predictions.
        target: Ground-truth labels.
        topk: Top-k values to compute.

    Returns:
        List of accuracy values for each k.
    """


# --- Usage example: basic training setup ---
import timm
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
from timm.loss import LabelSmoothingCrossEntropy
from timm.utils import ModelEma, CheckpointSaver, AverageMeter

# Create model
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)

# Create AdamW optimizer (pass layer_decay=... to enable layer-wise LR decay)
optimizer = create_optimizer_v2(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
)

# Create learning rate scheduler
scheduler = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    warmup_lr=1e-5,
    min_lr=1e-6,
)

# Create loss function
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

# Create EMA
model_ema = ModelEma(model, decay=0.9999)

# Create checkpoint saver
saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    model_ema=model_ema,
    checkpoint_dir='./checkpoints',
    max_history=5,
)

# Metrics
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')

# --- Usage example: custom parameter groups ---
from timm.optim import param_groups_layer_decay, Lamb, Lookahead
# Create parameter groups with layer decay
param_groups = param_groups_layer_decay(
    model,
    weight_decay=0.05,
    layer_decay=0.8,
)

# Create LAMB optimizer
base_optimizer = Lamb(param_groups, lr=1e-3)

# Wrap with Lookahead
optimizer = Lookahead(base_optimizer, alpha=0.5, k=6)

# --- Type aliases used by this API ---
from typing import Optional, Union, List, Dict, Callable, Any, Tuple
import torch

# Optimizer and scheduler types
OptimizerType = torch.optim.Optimizer
# NOTE(review): `_LRScheduler` is the pre-2.0 private name; newer torch
# exposes the public `LRScheduler` alias — confirm the minimum torch version.
SchedulerType = torch.optim.lr_scheduler._LRScheduler

# Parameter types
ParamGroup = Dict[str, Any]
ParamGroups = List[ParamGroup]

# Loss function type
LossFunction = torch.nn.Module

# Metric types
MetricValue = Union[float, torch.Tensor]
MetricDict = Dict[str, MetricValue]

# Install with Tessl CLI
# npx tessl i tessl/pypi-timm