# PyTorch Image Models (timm): state-of-the-art computer vision models,
# training scripts, and utilities for image classification tasks.
#
# Data pipeline API: datasets, transforms, augmentation strategies, and
# high-performance data loaders for computer vision training and inference,
# with support for distributed training, mixed precision, and advanced
# augmentation techniques.
# NOTE(review): the mean/std defaults reference IMAGENET_DEFAULT_MEAN/STD,
# which are declared later in this file — the constants must precede this
# definition at import time; confirm ordering in the assembled module.
def create_loader(
    dataset,
    input_size: Union[int, tuple],
    batch_size: int,
    is_training: bool = False,
    use_prefetcher: bool = False,
    no_aug: bool = False,
    re_prob: float = 0.0,
    re_mode: str = 'const',
    re_count: int = 1,
    re_num_splits: int = 0,
    scale: tuple = (0.08, 1.0),
    ratio: tuple = (3. / 4., 4. / 3.),
    hflip: float = 0.5,
    vflip: float = 0.0,
    color_jitter: float = 0.4,
    auto_augment: Optional[str] = None,
    num_aug_repeats: int = 0,
    num_aug_splits: int = 0,
    interpolation: str = 'bilinear',
    mean: tuple = IMAGENET_DEFAULT_MEAN,
    std: tuple = IMAGENET_DEFAULT_STD,
    num_workers: int = 1,
    distributed: bool = False,
    collate_fn: Optional[Callable] = None,
    pin_memory: bool = False,
    use_multi_epochs_loader: bool = False,
    persistent_workers: bool = True,
    worker_seeding: str = 'all',
    **kwargs,
) -> torch.utils.data.DataLoader:
    """Create a DataLoader with TIMM's optimized configuration.

    Args:
        dataset: Dataset instance.
        input_size: Target input size (int or tuple).
        batch_size: Batch size for training/inference.
        is_training: Training mode with augmentations.
        use_prefetcher: Use CUDA prefetcher for performance.
        no_aug: Disable augmentations.
        re_prob: Random erasing probability.
        re_mode: Random erasing mode ('const', 'rand', 'pixel').
        re_count: Random erasing count.
        re_num_splits: Random erasing number of splits.
        scale: Random resized crop scale range.
        ratio: Random resized crop aspect ratio range.
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Color jitter factor.
        auto_augment: AutoAugment policy ('original', 'originalr', 'v0', 'v0r').
        num_aug_repeats: Number of augmentation repetitions.
        num_aug_splits: Number of augmentation splits.
        interpolation: Resize interpolation method.
        mean: Normalization mean values.
        std: Normalization standard deviation values.
        num_workers: Number of data loading workers.
        distributed: Enable distributed sampler.
        collate_fn: Custom collate function.
        pin_memory: Pin memory for GPU transfer.
        use_multi_epochs_loader: Use multi-epoch loader for efficiency.
        persistent_workers: Keep workers alive between epochs.
        worker_seeding: Worker random seeding strategy.

    Returns:
        Configured DataLoader instance.
    """


# Comprehensive transform pipelines with support for training and inference
# configurations.
# NOTE(review): mean/std defaults depend on IMAGENET_DEFAULT_MEAN/STD declared
# later in this file — confirm the constants precede this def at import time.
def create_transform(
    input_size: Union[int, tuple],
    is_training: bool = False,
    use_prefetcher: bool = False,
    no_aug: bool = False,
    scale: tuple = (0.08, 1.0),
    ratio: tuple = (3. / 4., 4. / 3.),
    hflip: float = 0.5,
    vflip: float = 0.0,
    color_jitter: float = 0.4,
    auto_augment: Optional[str] = None,
    interpolation: str = 'bilinear',
    mean: tuple = IMAGENET_DEFAULT_MEAN,
    std: tuple = IMAGENET_DEFAULT_STD,
    re_prob: float = 0.0,
    re_mode: str = 'const',
    re_count: int = 1,
    re_num_splits: int = 0,
    crop_pct: Optional[float] = None,
    tf_preprocessing: bool = False,
    separate: bool = False,
    **kwargs,
):
    """Create image transform pipeline.

    Args:
        input_size: Target input size for transforms.
        is_training: Use training transforms with augmentation.
        use_prefetcher: Skip normalization for CUDA prefetcher.
        no_aug: Disable all augmentations.
        scale: Random resized crop scale range.
        ratio: Random resized crop aspect ratio range.
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Color jitter strength.
        auto_augment: AutoAugment policy name.
        interpolation: Resize interpolation method.
        mean: Normalization mean values.
        std: Normalization standard deviation values.
        re_prob: Random erasing probability.
        re_mode: Random erasing mode.
        re_count: Random erasing count.
        re_num_splits: Random erasing splits.
        crop_pct: Center crop percentage.
        tf_preprocessing: Use TensorFlow-style preprocessing.
        separate: Return transforms as separate list.

    Returns:
        Transform function or list of transforms.
    """


# Factory function for creating various dataset types including ImageNet,
# CIFAR, and custom datasets.
def create_dataset(
    name: str,
    root: str,
    split: str = 'validation',
    is_training: bool = False,
    class_map: Optional[dict] = None,
    load_bytes: bool = False,
    img_mode: str = 'RGB',
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    **kwargs,
):
    """Create dataset instance.

    Args:
        name: Dataset name or path pattern.
        root: Root directory containing dataset.
        split: Dataset split ('train', 'validation', 'test').
        is_training: Training mode configuration.
        class_map: Custom class mapping.
        load_bytes: Load images as bytes instead of PIL.
        img_mode: Image mode ('RGB', 'L', etc.).
        transform: Image transforms.
        target_transform: Target/label transforms.
        **kwargs: Dataset-specific arguments.

    Returns:
        Dataset instance.
    """


class ImageDataset(torch.utils.data.Dataset):
    """Standard image dataset for classification tasks.

    Args:
        root: Root directory path.
        reader: Image reader instance.
        class_to_idx: Class name to index mapping.
        transform: Image transforms.
        target_transform: Target transforms.
    """

    def __init__(
        self,
        root: str,
        reader: Optional[Any] = None,
        class_to_idx: Optional[dict] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ): ...
class IterableImageDataset(torch.utils.data.IterableDataset):
    """Iterable image dataset for streaming large datasets.

    Args:
        root: Root directory or file pattern.
        reader: Image reader instance.
        split: Dataset split name.
        is_training: Training mode.
        batch_size: Batch size for iteration.
        transform: Image transforms.
    """

    def __init__(
        self,
        root: str,
        reader: Optional[Any] = None,
        split: str = 'train',
        is_training: bool = False,
        batch_size: Optional[int] = None,
        transform: Optional[Callable] = None,
    ): ...
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper for the AugMix augmentation technique.

    Args:
        dataset: Base dataset.
        num_splits: Number of augmentation splits.
        alpha: Mixing parameter.
        width: Number of augmentation chains.
        depth: Depth of augmentation chains.
        blended: Use blended mixing.
    """

    def __init__(
        self,
        dataset,
        num_splits: int = 2,
        alpha: float = 1.0,
        width: int = 3,
        depth: int = -1,
        blended: bool = False,
    ): ...


class ToTensor:
    """Convert PIL Image to tensor."""

    def __call__(self, pic): ...
class ToNumpy:
    """Convert tensor to numpy array."""

    def __call__(self, tensor): ...
class RandomResizedCropAndInterpolation:
    """Random resized crop with configurable interpolation.

    Args:
        size: Target output size.
        scale: Random crop scale range.
        ratio: Random crop aspect ratio range.
        interpolation: Interpolation method.
    """

    def __init__(
        self,
        size: Union[int, tuple],
        scale: tuple = (0.08, 1.0),
        ratio: tuple = (3. / 4., 4. / 3.),
        interpolation: str = 'bilinear',
    ): ...


class RandAugment:
    """RandAugment augmentation.

    Args:
        ops: List of augmentation operations.
        num_layers: Number of augmentation layers to apply.
        magnitude: Augmentation magnitude.
    """

    def __init__(
        self,
        ops: List[str],
        num_layers: int = 2,
        magnitude: int = 9,
    ): ...
class AutoAugment:
    """AutoAugment data augmentation.

    Args:
        policy: AutoAugment policy name.
    """

    def __init__(self, policy: str = 'original'): ...
class TrivialAugmentWide:
    """TrivialAugment Wide augmentation strategy."""

    def __init__(self): ...
class Mixup:
    """Mixup data augmentation.

    Args:
        mixup_alpha: Mixup interpolation coefficient.
        cutmix_alpha: CutMix interpolation coefficient.
        cutmix_minmax: CutMix min/max box size ratios.
        prob: Probability of applying mixup/cutmix.
        switch_prob: Probability of switching between mixup and cutmix.
        mode: Mixup mode ('batch', 'pair', 'elem').
        correct_lam: Apply lambda correction.
        label_smoothing: Label smoothing value.
        num_classes: Number of classes.
    """

    def __init__(
        self,
        mixup_alpha: float = 1.0,
        cutmix_alpha: float = 0.0,
        cutmix_minmax: Optional[tuple] = None,
        prob: float = 1.0,
        switch_prob: float = 0.5,
        mode: str = 'batch',
        correct_lam: bool = True,
        label_smoothing: float = 0.1,
        num_classes: int = 1000,
    ): ...


# ImageNet normalization constants
IMAGENET_DEFAULT_MEAN: tuple = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD: tuple = (0.229, 0.224, 0.225)

# ImageNet Inception normalization
IMAGENET_INCEPTION_MEAN: tuple = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD: tuple = (0.5, 0.5, 0.5)

# OpenAI CLIP normalization
OPENAI_CLIP_MEAN: tuple = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD: tuple = (0.26862954, 0.26130258, 0.27577711)


def resolve_data_config(
    args=None,
    pretrained_cfg: Optional[dict] = None,
    model: Optional[torch.nn.Module] = None,
    use_test_size: bool = False,
    verbose: bool = False,
) -> dict:
    """Resolve data configuration from model, args, or defaults.

    Args:
        args: Argument namespace with data config.
        pretrained_cfg: Pretrained model configuration.
        model: Model instance to extract config from.
        use_test_size: Use test/inference input size.
        verbose: Print resolved configuration.

    Returns:
        Dictionary with resolved data configuration.
    """
def resolve_model_data_config(
    model: torch.nn.Module,
    args=None,
    pretrained_cfg: Optional[dict] = None,
    use_test_size: bool = False,
    verbose: bool = False,
) -> dict:
    """Resolve data configuration specifically from model.

    Args:
        model: Model instance.
        args: Additional arguments.
        pretrained_cfg: Pretrained configuration override.
        use_test_size: Use inference input size.
        verbose: Print configuration details.

    Returns:
        Model-specific data configuration.
    """


def create_reader(
    name: str,
    root: str,
    split: str = 'train',
    **kwargs,
):
    """Create image reader for different data formats.

    Args:
        name: Reader type ('', 'hfds', 'tfds', 'wds').
        root: Data root path.
        split: Dataset split.
        **kwargs: Reader-specific arguments.

    Returns:
        Configured reader instance.
    """
def get_img_extensions() -> set:
    """Get supported image file extensions.

    Returns:
        Set of supported extensions.
    """
def is_img_extension(filename: str) -> bool:
    """Check if filename has supported image extension.

    Args:
        filename: File name to check.

    Returns:
        True if supported image format.
    """


# ---------------------------------------------------------------------------
# Usage examples. Wrapped in functions so they are NOT executed at import
# time; they illustrate the API and require a real timm installation.
# ---------------------------------------------------------------------------

def _example_basic_pipeline():
    """Example: build transforms, datasets, and loaders for train/val."""
    import timm
    from timm.data import create_loader, create_transform, create_dataset

    # Create transforms for training and validation
    train_transform = create_transform(
        input_size=224,
        is_training=True,
        hflip=0.5,
        color_jitter=0.4,
        auto_augment='original',
    )
    val_transform = create_transform(
        input_size=224,
        is_training=False,
    )
    # Create datasets
    train_dataset = create_dataset(
        'imagefolder',
        root='/path/to/train',
        transform=train_transform,
    )
    val_dataset = create_dataset(
        'imagefolder',
        root='/path/to/val',
        transform=val_transform,
    )
    # Create data loaders
    train_loader = create_loader(
        train_dataset,
        input_size=224,
        batch_size=32,
        is_training=True,
        num_workers=4,
    )
    val_loader = create_loader(
        val_dataset,
        input_size=224,
        batch_size=64,
        is_training=False,
        num_workers=4,
    )
    return train_loader, val_loader


def _example_mixup_training(dataset):
    """Example: advanced augmentation with Mixup/CutMix in a training loop."""
    from timm.data import Mixup, create_loader

    # Create mixup augmentation
    mixup = Mixup(
        mixup_alpha=0.8,
        cutmix_alpha=1.0,
        prob=1.0,
        switch_prob=0.5,
        mode='batch',
        label_smoothing=0.1,
        num_classes=1000,
    )
    # Create loader with advanced augmentation; mixup_alpha/cutmix_alpha are
    # absorbed by create_loader's **kwargs.
    train_loader = create_loader(
        dataset,
        input_size=224,
        batch_size=32,
        is_training=True,
        auto_augment='rand-m9-mstd0.5-inc1',
        re_prob=0.25,
        mixup_alpha=0.8,
        cutmix_alpha=1.0,
    )
    # Apply mixup in training loop
    for batch_idx, (input, target) in enumerate(train_loader):
        if mixup is not None:
            input, target = mixup(input, target)
        # ... training code


from typing import Optional, Union, List, Dict, Callable, Any, Tuple
import torch

# Transform and dataset types
TransformType = Callable[[Any], torch.Tensor]
DatasetType = torch.utils.data.Dataset
LoaderType = torch.utils.data.DataLoader

# Data configuration
DataConfig = Dict[str, Any]
AugmentConfig = Dict[str, Any]

# Common type aliases
ImageSize = Union[int, Tuple[int, int]]
NormStats = Tuple[float, float, float]

# Install with the Tessl CLI:
#   npx tessl i tessl/pypi-timm