PyTorch Image Models (timm): a library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks.
This module covers general utilities for distributed training, model management, checkpointing, logging, and other supporting functionality for production computer vision workflows.
It begins with functions for model management, parameter manipulation, and model state operations.
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Unwrap model from DDP/EMA/other wrappers.

    Args:
        model: Wrapped model instance

    Returns:
        Unwrapped base model
    """
    ...


def get_state_dict(
    model: torch.nn.Module,
    unwrap_fn: Callable = unwrap_model,
) -> Dict[str, Any]:
    """Get model state dictionary with unwrapping.

    Args:
        model: Model to get state dict from
        unwrap_fn: Function to unwrap model

    Returns:
        Model state dictionary
    """
    ...


def freeze(model: torch.nn.Module) -> None:
    """Freeze all model parameters (disable gradients).

    Args:
        model: Model to freeze
    """
    ...


def unfreeze(model: torch.nn.Module) -> None:
    """Unfreeze all model parameters (enable gradients).

    Args:
        model: Model to unfreeze
    """
    ...


def reparameterize_model(
    model: torch.nn.Module,
    **kwargs,
) -> torch.nn.Module:
    """Reparameterize model for inference optimization.

    Args:
        model: Model to reparameterize
        **kwargs: Reparameterization options

    Returns:
        Reparameterized model
    """
    ...


# Functions for initializing and managing distributed training across
# multiple devices and nodes.
def init_distributed_device(args) -> Tuple[torch.device, int]:
    """Initialize distributed training device and process rank.

    Args:
        args: Arguments namespace with distributed training configuration

    Returns:
        Tuple of (device, world_size) for distributed training setup
    """
    ...


def distribute_bn(
    model: torch.nn.Module,
    world_size: int,
    reduce: bool = False,
) -> None:
    """Distribute batch normalization statistics across processes.

    Args:
        model: Model with batch norm layers
        world_size: Number of distributed processes
        reduce: Reduce statistics across processes
    """
    ...


def reduce_tensor(
    tensor: torch.Tensor,
    world_size: int = 1,
) -> torch.Tensor:
    """Reduce tensor across distributed processes.

    Args:
        tensor: Tensor to reduce
        world_size: Number of processes

    Returns:
        Reduced tensor
    """
    ...


def world_info_from_env() -> Tuple[int, int, int]:
    """Get distributed world info from environment variables.

    Returns:
        Tuple of (local_rank, world_rank, world_size)
    """
    ...


def is_distributed_env() -> bool:
    """Check if running in distributed environment.

    Returns:
        True if distributed environment detected
    """
    ...


# Utilities for managing mixed precision training with automatic mixed
# precision (AMP).
class ApexScaler:
    """Gradient scaler using NVIDIA Apex.

    Args:
        loss_scale: Initial loss scaling factor ('dynamic' or fixed value)
        init_scale: Initial scale value
        scale_factor: Scale adjustment factor
        scale_window: Scale adjustment window
    """

    def __init__(
        self,
        loss_scale: str = 'dynamic',
        init_scale: float = 2.**16,
        scale_factor: float = 2.0,
        scale_window: int = 2000,
    ): ...

    def scale_loss(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer): ...

    def unscale_grads(self, optimizer: torch.optim.Optimizer): ...

    def update_scale(self, overflow: bool): ...


class NativeScaler:
    """Native PyTorch gradient scaler for mixed precision.

    Args:
        enabled: Enable gradient scaling
        init_scale: Initial scaling factor
        growth_factor: Scale growth factor
        backoff_factor: Scale backoff factor
        growth_interval: Interval for scale growth
    """

    def __init__(
        self,
        enabled: bool = True,
        init_scale: float = 2.**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ): ...

    def scale(self, loss: torch.Tensor) -> torch.Tensor: ...

    def step(self, optimizer: torch.optim.Optimizer) -> None: ...

    def update(self) -> None: ...


# Functions for managing CUDA operations, JIT compilation, and performance
# optimization.
def set_jit_legacy(enable: bool) -> None:
    """Set legacy JIT mode.

    Args:
        enable: Enable legacy JIT mode
    """
    ...


def set_jit_fuser(fuser_name: str) -> None:
    """Set JIT fuser type.

    Args:
        fuser_name: Name of fuser ('te', 'old', 'nvfuser')
    """
    ...


def random_seed(seed: int, rank: int = 0) -> None:
    """Set random seed for reproducibility across all libraries.

    Args:
        seed: Random seed value
        rank: Process rank for distributed training
    """
    ...


# Utilities for setting up logging, argument parsing, and experiment
# configuration.
def setup_default_logging(
    default_level: int = logging.INFO,
    log_path: str = '',
    **kwargs,
) -> None:
    """Setup default logging configuration.

    Args:
        default_level: Default logging level
        log_path: Path for log file
        **kwargs: Additional logging configuration
    """
    ...


def natural_key(string_: str) -> List[Union[int, str]]:
    """Natural sorting key function for strings with numbers.

    Args:
        string_: String to create key for

    Returns:
        List of components for natural sorting
    """
    ...


def add_bool_arg(
    parser,
    name: str,
    default: bool = False,
    help: str = '',
) -> None:
    """Add boolean argument to argument parser with --name/--no-name pattern.

    Args:
        parser: ArgumentParser instance
        name: Argument name
        default: Default value
        help: Help text
    """
    ...


# Functions for managing training outputs, experiment directories, and
# result summaries.
def update_summary(
    epoch: int,
    train_metrics: Dict[str, float],
    eval_metrics: Dict[str, float],
    filename: str,
    lr: float = None,
    write_header: bool = False,
    log_wandb: bool = False,
) -> None:
    """Update training summary with metrics.

    Args:
        epoch: Current epoch
        train_metrics: Training metrics dictionary
        eval_metrics: Evaluation metrics dictionary
        filename: Summary file path
        lr: Current learning rate
        write_header: Write CSV header
        log_wandb: Log to Weights & Biases
    """
    ...


def get_outdir(path: str, *paths: str, inc: bool = False) -> str:
    """Get output directory for experiments.

    Args:
        path: Base output path
        *paths: Additional path components
        inc: Auto-increment directory name

    Returns:
        Output directory path
    """
    ...


class AverageMeter:
    """Computes and stores the average and current value for metrics tracking.

    Args:
        name: Name of the metric
        fmt: Format string for display
    """

    def __init__(self, name: str = '', fmt: str = ':f'): ...

    def reset(self) -> None:
        """Reset all statistics to initial values."""
        ...

    def update(self, val: float, n: int = 1) -> None:
        """Update meter with new value.

        Args:
            val: New value to add
            n: Number of samples the value represents
        """
        ...

    def __str__(self) -> str:
        """String representation of current meter state."""
        ...


def accuracy(
    output: torch.Tensor,
    target: torch.Tensor,
    topk: Tuple[int, ...] = (1,),
) -> List[torch.Tensor]:
    """Compute accuracy for specified top-k values.

    Args:
        output: Model output predictions [batch_size, num_classes]
        target: Ground truth labels [batch_size]
        topk: Tuple of k values for top-k accuracy

    Returns:
        List of accuracy tensors for each k value
    """
    ...


class ModelEma:
    """Model Exponential Moving Average for maintaining shadow weights.

    Args:
        model: Model to track with EMA
        decay: EMA decay rate (default: 0.9999)
        device: Device to store EMA parameters
        resume: Path to resume EMA from checkpoint
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        device: torch.device = None,
        resume: str = '',
    ): ...

    def update(self, model: torch.nn.Module) -> None:
        """Update EMA parameters from model.

        Args:
            model: Source model for updates
        """
        ...

    def set(self, model: torch.nn.Module) -> None:
        """Set EMA parameters from model (copy all parameters).

        Args:
            model: Source model to copy from
        """
        ...


class ModelEmaV2:
    """Model EMA v2 with improved decay adjustment based on training progress.

    Args:
        model: Model to track
        decay: Base decay rate
        decay_type: Type of decay adjustment ('exponential', 'linear')
        device: Device for EMA parameters
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        decay_type: str = 'exponential',
        device: torch.device = None,
    ): ...


class ModelEmaV3:
    """Model EMA v3 with performance optimizations and memory efficiency.

    Args:
        model: Model to track
        decay: EMA decay rate
        update_after_step: Steps before starting EMA updates
        use_ema_warmup: Use warmup for EMA updates
        inv_gamma: Inverse gamma for warmup
        power: Power for warmup
        min_value: Minimum decay value
        device: Device for parameters
    """

    def __init__(
        self,
        model: torch.nn.Module,
        decay: float = 0.9999,
        update_after_step: int = 100,
        use_ema_warmup: bool = False,
        inv_gamma: float = 1.0,
        power: float = 2/3,
        min_value: float = 0.0,
        device: torch.device = None,
    ): ...


class CheckpointSaver:
    """Saves model checkpoints with configurable retention and recovery policies.

    Args:
        model: Model to save
        optimizer: Optimizer state to save
        args: Training arguments/configuration
        model_ema: EMA model to save
        amp_scaler: Mixed precision scaler
        checkpoint_prefix: Prefix for checkpoint filenames
        recovery_prefix: Prefix for recovery checkpoints
        checkpoint_dir: Directory for regular checkpoints
        recovery_dir: Directory for recovery checkpoints
        decreasing: Whether monitored metric is decreasing (lower is better)
        max_history: Maximum number of checkpoints to keep
        unwrap_fn: Function to unwrap model before saving
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        args=None,
        model_ema: ModelEma = None,
        amp_scaler=None,
        checkpoint_prefix: str = 'checkpoint',
        recovery_prefix: str = 'recovery',
        checkpoint_dir: str = '',
        recovery_dir: str = '',
        decreasing: bool = False,
        max_history: int = 10,
        unwrap_fn: Callable = unwrap_model,
    ): ...

    def save_checkpoint(
        self,
        epoch: int,
        metric: float = None,
    ) -> Tuple[str, bool]:
        """Save checkpoint if metric improved.

        Args:
            epoch: Current epoch number
            metric: Metric value for comparison

        Returns:
            Tuple of (checkpoint_path, is_best)
        """
        ...

    def save_recovery(
        self,
        epoch: int,
        batch_idx: int = 0,
    ) -> str:
        """Save recovery checkpoint for resuming interrupted training.

        Args:
            epoch: Current epoch
            batch_idx: Current batch index

        Returns:
            Path to saved recovery checkpoint
        """
        ...


import logging
import timm
from timm.utils import (
    setup_default_logging, random_seed, ModelEma,
    CheckpointSaver, AverageMeter, accuracy,
)

# Setup logging
setup_default_logging(log_path='training.log')
logger = logging.getLogger(__name__)

# Set random seed for reproducibility
random_seed(42, rank=0)

# Create model and training components
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Setup EMA tracking
model_ema = ModelEma(model, decay=0.9999)

# Setup checkpoint saving
saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    model_ema=model_ema,
    checkpoint_dir='./checkpoints',
    max_history=5,
    decreasing=False,  # Higher accuracy is better
)

# Setup metrics tracking
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')

# Distributed-training utilities used by the next example
from timm.utils import (
    init_distributed_device, distribute_bn, reduce_tensor,
    is_distributed_env,
)
# Initialize distributed training
device, world_size = init_distributed_device(args)
model = model.to(device)

if is_distributed_env():
    # Synchronize batch norm statistics
    distribute_bn(model, world_size, reduce=True)
    # Wrap model for distributed training
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[device], find_unused_parameters=False
    )


# In training loop - reduce metrics across processes
def train_epoch(model, loader, optimizer, device, world_size):
    losses = AverageMeter('Loss')
    for batch_idx, (input, target) in enumerate(loader):
        input, target = input.to(device), target.to(device)
        output = model(input)
        loss = criterion(output, target)

        # Backward and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Reduce loss across processes
        if world_size > 1:
            loss = reduce_tensor(loss, world_size)
        losses.update(loss.item(), input.size(0))
    return losses.avg


from timm.utils import NativeScaler
# Setup mixed precision training
scaler = NativeScaler()
model = model.to(device)


def train_step(model, input, target, optimizer, scaler):
    optimizer.zero_grad()

    # Forward pass with autocast
    with torch.cuda.amp.autocast():
        output = model(input)
        loss = criterion(output, target)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


def train_model():
    setup_default_logging()
    random_seed(42)

    # Model setup
    model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

    # Training utilities
    model_ema = ModelEmaV2(model, decay=0.9999)
    scaler = NativeScaler()
    saver = CheckpointSaver(
        model, optimizer, model_ema=model_ema, amp_scaler=scaler,
        checkpoint_dir='./checkpoints'
    )

    # Metrics
    train_losses = AverageMeter('Train Loss')
    train_acc1 = AverageMeter('Train Acc@1')

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_losses.reset()
        train_acc1.reset()

        for batch_idx, (input, target) in enumerate(train_loader):
            input, target = input.to(device), target.to(device)

            # Mixed precision forward pass
            with torch.cuda.amp.autocast():
                output = model(input)
                loss = criterion(output, target)

            # Backward pass
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Update EMA
            model_ema.update(model)

            # Metrics
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            train_losses.update(loss.item(), input.size(0))
            train_acc1.update(acc1.item(), input.size(0))

        # Validation and checkpointing
        val_acc = validate(model_ema.module, val_loader)
        saver.save_checkpoint(epoch, val_acc)

        logger.info(f'Epoch {epoch}: Train Loss {train_losses.avg:.4f}, '
                    f'Train Acc {train_acc1.avg:.2f}%, Val Acc {val_acc:.2f}%')


from typing import Optional, Union, List, Dict, Callable, Any, Tuple
import torch
import logging

# Device and distributed types
DeviceType = torch.device
WorldInfo = Tuple[int, int, int]  # (local_rank, world_rank, world_size)

# Metrics types
MetricValue = Union[float, torch.Tensor]
MetricDict = Dict[str, MetricValue]

# Checkpoint types
CheckpointDict = Dict[str, Any]
UnwrapFunction = Callable[[torch.nn.Module], torch.nn.Module]

# Scaler types
LossScaler = Union[torch.cuda.amp.GradScaler, Any]

# Logging types
LogLevel = int
Logger = logging.Logger

# Utility function types
SeedFunction = Callable[[int, int], None]
ReduceFunction = Callable[[torch.Tensor, int], torch.Tensor]

# Install with Tessl CLI:
npx tessl i tessl/pypi-timm