PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks
npx @tessl/cli install tessl/pypi-timm@1.0.0
A comprehensive collection of image models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and reference training/validation scripts that provide state-of-the-art computer vision models with reproducible ImageNet training results.
pip install timm
import timm
Common patterns for model creation:
from timm import create_model, list_models
For working with specific components:
from timm.data import create_loader, create_transform
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
from timm.loss import LabelSmoothingCrossEntropy
import timm
import torch
# Create a pretrained model
model = timm.create_model('resnet50', pretrained=True, num_classes=1000)

# List available models
available_models = timm.list_models('*resnet*')  # All ResNet variants
pretrained_models = timm.list_pretrained('efficientnet*')  # EfficientNet models with pretrained weights

# Create model for feature extraction
feature_model = timm.create_model('resnet50', pretrained=True, features_only=True)

# Inference on an image
model.eval()
# The forward pass runs inside no_grad(); the body must be indented under the
# `with` statement or the example does not parse.
with torch.no_grad():
    # Input tensor should be [batch_size, 3, height, width]
    input_tensor = torch.randn(1, 3, 224, 224)
    predictions = model(input_tensor)

# Get model configuration
cfg = timm.get_pretrained_cfg('resnet50')
# NOTE(review): the subscripting below assumes the returned config is dict-like
# and not None; the listed signature returns Optional[PretrainedCfg] - confirm
# against the installed timm version.
print(f"Model input size: {cfg['input_size']}")
print(f"Model mean: {cfg['mean']}")
print(f"Model std: {cfg['std']}")
TIMM is organized around several key components that work together to provide a complete computer vision ecosystem:
The library's modular design allows users to mix and match components, from using pretrained models for inference to building custom training pipelines with TIMM's optimizers and data loaders.
Core functionality for discovering, creating, and configuring computer vision models from TIMM's extensive collection of architectures.
def create_model(
    model_name: str,
    pretrained: bool = False,
    pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
    pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
    checkpoint_path: Optional[Union[str, Path]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    scriptable: Optional[bool] = None,
    exportable: Optional[bool] = None,
    no_jit: Optional[bool] = None,
    **kwargs: Any
) -> torch.nn.Module:
    """Signature stub for the model factory.

    Takes a model name (e.g. ``'resnet50'``) and is annotated to return a
    ``torch.nn.Module``. The usage example in this document passes
    ``num_classes`` and ``features_only`` as extra keyword arguments, so
    those travel through ``**kwargs``. The body is intentionally elided.
    """
    ...
def list_models(
    filter: Union[str, List[str]] = '',
    module: Union[str, List[str]] = '',
    pretrained: bool = False,
    exclude_filters: Union[str, List[str]] = '',
    name_matches_cfg: bool = False,
    include_tags: Optional[bool] = None
) -> List[str]:
    """Signature stub for listing model names.

    ``filter`` accepts a wildcard pattern such as ``'*resnet*'`` (as used in
    this document's example) or a list of patterns; the result is annotated
    as a list of strings. The body is intentionally elided.
    """
    ...
def list_pretrained(
    filter: Union[str, List[str]] = '',
    exclude_filters: str = ''
) -> List[str]:
    """Signature stub for listing models that have pretrained weights.

    The example in this document calls it with a wildcard pattern
    (``'efficientnet*'``). The body is intentionally elided.
    """
    ...
def is_model(model_name: str) -> bool:
    """Signature stub: takes a model name, annotated to return a bool."""
    ...
def list_modules() -> List[str]:
    """Signature stub: no arguments, annotated to return a list of strings."""
    ...
def model_entrypoint(
    model_name: str,
    module_filter: Optional[str] = None
) -> Callable[..., Any]:
    """Signature stub: takes a model name and an optional module filter.

    Annotated to return a callable. The body is intentionally elided.
    """
    ...
def is_model_pretrained(model_name: str) -> bool:
    """Signature stub: takes a model name, annotated to return a bool."""
    ...
def get_pretrained_cfg(
    model_name: str,
    allow_unregistered: bool = True
) -> Optional[PretrainedCfg]:
    """Signature stub for fetching a model's pretrained configuration.

    The usage example in this document reads ``'input_size'``, ``'mean'``
    and ``'std'`` from the result. The return annotation is optional, so
    callers may receive ``None``. The body is intentionally elided.
    """
    ...
def get_pretrained_cfg_value(
model_name: str,
cfg_key: str
) -> Optional[Any]: ...
Complete data pipeline including datasets, transforms, augmentation strategies, and high-performance data loaders optimized for computer vision training and inference.
def create_loader(
    dataset,
    input_size: Union[int, tuple],
    batch_size: int,
    is_training: bool = False,
    use_prefetcher: bool = False,
    no_aug: bool = False,
    **kwargs
) -> torch.utils.data.DataLoader:
    """Signature stub for building a data loader.

    ``input_size`` is an int or a tuple. It is spelled with ``typing.Union``
    to match the other signatures in this document and to stay importable on
    Python versions older than 3.10, where ``int | tuple`` fails at
    definition time. Annotated to return a ``torch.utils.data.DataLoader``.
    The body is intentionally elided.
    """
    ...
def create_transform(
    input_size: Union[int, tuple],
    is_training: bool = False,
    use_prefetcher: bool = False,
    no_aug: bool = False,
    scale: tuple = (0.08, 1.0),
    ratio: tuple = (3./4., 4./3.),
    **kwargs
):
    """Signature stub for building an image transform.

    ``input_size`` is an int or a tuple. It is spelled with ``typing.Union``
    to match the other signatures in this document and to stay importable on
    Python versions older than 3.10. ``scale`` defaults to ``(0.08, 1.0)``
    and ``ratio`` to ``(3/4, 4/3)``. No return annotation is given. The body
    is intentionally elided.
    """
    ...
def create_dataset(
name: str,
root: str,
split: str = 'validation',
is_training: bool = False,
**kwargs
): ...
Extensive collection of neural network building blocks including activations, attention mechanisms, convolutions, normalization layers, and specialized components for vision architectures.
# Layer creation utilities
def create_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, List[int]],
    **kwargs
) -> torch.nn.Module:
    """Signature stub for a convolution-layer factory.

    Takes input/output channel counts and a kernel size (an int or a list
    of ints); annotated to return a ``torch.nn.Module``. The body is
    intentionally elided.
    """
    ...
def create_norm_layer(
    layer_name: str,
    num_features: int,
    **kwargs
) -> torch.nn.Module:
    """Signature stub for a normalization-layer factory.

    Takes a layer name and a feature count; annotated to return a
    ``torch.nn.Module``. The body is intentionally elided.
    """
    ...
def create_act_layer(
    name: Optional[str],
    inplace: Optional[bool] = None,
    **kwargs
) -> Optional[torch.nn.Module]:
    """Signature stub for an activation-layer factory.

    ``name`` may be ``None`` and the return annotation is optional, so a
    module is not guaranteed. The body is intentionally elided.
    """
    ...
# Configuration functions
def is_scriptable() -> bool:
    """Signature stub: no arguments, annotated to return a bool."""
    ...
def is_exportable() -> bool:
    """Signature stub: no arguments, annotated to return a bool."""
    ...
def set_scriptable(mode: bool) -> object:
    """Signature stub: takes a bool mode.

    The original listing marks the returned object as a context manager.
    """
    ...
def set_exportable(mode: bool) -> object: ... # Context manager
Comprehensive training utilities including optimizers, learning rate schedulers, loss functions, and training helpers for building complete training pipelines.
def create_optimizer_v2(
    model_or_params,
    opt: str = 'sgd',
    lr: float = 0.01,
    weight_decay: float = 0.0,
    momentum: float = 0.9,
    **kwargs
):
    """Signature stub for an optimizer factory.

    Defaults shown here: ``opt='sgd'``, ``lr=0.01``, ``weight_decay=0.0``,
    ``momentum=0.9``. No return annotation is given. The body is
    intentionally elided.
    """
    ...
def create_scheduler_v2(
    optimizer,
    sched: str = 'step',
    epochs: int = 200,
    **kwargs
):
    """Signature stub for a learning-rate scheduler factory.

    Defaults shown here: ``sched='step'``, ``epochs=200``. No return
    annotation is given. The body is intentionally elided.
    """
    ...
# Loss functions
class LabelSmoothingCrossEntropy(torch.nn.Module):
    """Class stub for a loss module; subclasses ``torch.nn.Module``."""
    ...
class SoftTargetCrossEntropy(torch.nn.Module): ...
Advanced functionality for extracting features from models, analyzing model architecture, and manipulating pretrained models for custom use cases.
def create_feature_extractor(
    model: torch.nn.Module,
    return_nodes: Union[dict, list],
    **kwargs
):
    """Signature stub for wrapping a model as a feature extractor.

    ``return_nodes`` is a dict or a list. It is spelled with
    ``typing.Union`` to match the other signatures in this document and to
    stay importable on Python versions older than 3.10. No return
    annotation is given. The body is intentionally elided.
    """
    ...
class FeatureHookNet(torch.nn.Module):
    """Class stub listed with the feature-extraction API; a ``torch.nn.Module`` subclass."""
    ...
class FeatureDictNet(torch.nn.Module):
    """Class stub listed with the feature-extraction API; a ``torch.nn.Module`` subclass."""
    ...
# Model manipulation
def adapt_input_conv(
model: torch.nn.Module,
in_chans: int,
conv_layer: str = None
): ...
General utilities for distributed training, model management, checkpointing, logging, and other supporting functionality for production computer vision workflows.
# Model utilities
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Signature stub: takes a module, annotated to return a module."""
    ...
def freeze(model: torch.nn.Module) -> None:
    """Signature stub: takes a module and returns nothing."""
    ...
def unfreeze(model: torch.nn.Module) -> None:
    """Signature stub: takes a module and returns nothing."""
    ...
# Training utilities
class ModelEma:
    """Class stub listed under training utilities; members are not shown here."""
    ...
class CheckpointSaver:
    """Class stub listed under training utilities; members are not shown here."""
    ...
class AverageMeter:
    """Class stub listed under training utilities; members are not shown here."""
    ...
# Distributed training
def init_distributed_device(args) -> tuple:
    """Signature stub: takes an ``args`` object, annotated to return a tuple.

    The tuple's contents are not specified in this listing.
    """
    ...
def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor: ...
from typing import Optional, Union, List, Dict, Tuple, Callable, Any
from pathlib import Path
import torch
# Common type aliases used throughout TIMM
# Each alias below is a plain module-level assignment to an existing torch or
# typing object; no new classes are defined here.
ModelType = torch.nn.Module
OptimizerType = torch.optim.Optimizer
# NOTE(review): _LRScheduler is an underscore-prefixed (conventionally private)
# torch name - confirm it exists in the torch version in use.
SchedulerType = torch.optim.lr_scheduler._LRScheduler
# A transform is typed as a one-argument callable that returns a tensor.
TransformType = Callable[[Any], torch.Tensor]
DatasetType = torch.utils.data.Dataset
LoaderType = torch.utils.data.DataLoader
# Configuration types
# All three configuration aliases are the same string-keyed mapping type.
# NOTE(review): PretrainedCfg is declared as a plain dict alias here - verify
# against the library's actual definition.
ConfigDict = Dict[str, Any]
PretrainedCfg = Dict[str, Any]
ModelCfg = Dict[str, Any]