tessl/pypi-timm

PyTorch Image Models library providing state-of-the-art computer vision models, training scripts, and utilities for image classification tasks.

Data Processing and Loading

Comprehensive data pipeline including datasets, transforms, augmentation strategies, and high-performance data loaders optimized for computer vision training and inference.

Capabilities

Data Loader Creation

High-performance data loaders with support for distributed training, mixed precision, and advanced augmentation techniques.

def create_loader(
    dataset,
    input_size: Union[int, tuple],
    batch_size: int,
    is_training: bool = False,
    use_prefetcher: bool = False,
    no_aug: bool = False,
    re_prob: float = 0.0,
    re_mode: str = 'const',
    re_count: int = 1,
    re_num_splits: int = 0,
    scale: tuple = (0.08, 1.0),
    ratio: tuple = (3./4., 4./3.),
    hflip: float = 0.5,
    vflip: float = 0.0,
    color_jitter: float = 0.4,
    auto_augment: Optional[str] = None,
    num_aug_repeats: int = 0,
    num_aug_splits: int = 0,
    interpolation: str = 'bilinear',
    mean: tuple = IMAGENET_DEFAULT_MEAN,
    std: tuple = IMAGENET_DEFAULT_STD,
    num_workers: int = 1,
    distributed: bool = False,
    collate_fn: Optional[Callable] = None,
    pin_memory: bool = False,
    use_multi_epochs_loader: bool = False,
    persistent_workers: bool = True,
    worker_seeding: str = 'all',
    **kwargs
) -> torch.utils.data.DataLoader:
    """
    Create a DataLoader with TIMM's optimized configuration.

    Args:
        dataset: Dataset instance
        input_size: Target input size (int or tuple)
        batch_size: Batch size for training/inference
        is_training: Training mode with augmentations
        use_prefetcher: Use CUDA prefetcher for performance
        no_aug: Disable augmentations
        re_prob: Random erasing probability
        re_mode: Random erasing mode ('const', 'rand', 'pixel')
        re_count: Random erasing count
        re_num_splits: Random erasing number of splits
        scale: Random resized crop scale range
        ratio: Random resized crop aspect ratio range
        hflip: Horizontal flip probability
        vflip: Vertical flip probability
        color_jitter: Color jitter factor
        auto_augment: AutoAugment/RandAugment policy string ('original', 'originalr', 'v0', 'v0r', or 'rand-...' / 'augmix-...' config strings)
        num_aug_repeats: Number of augmentation repetitions
        num_aug_splits: Number of augmentation splits
        interpolation: Resize interpolation method
        mean: Normalization mean values
        std: Normalization standard deviation values
        num_workers: Number of data loading workers
        distributed: Enable distributed sampler
        collate_fn: Custom collate function
        pin_memory: Pin memory for GPU transfer
        use_multi_epochs_loader: Use multi-epoch loader for efficiency
        persistent_workers: Keep workers alive between epochs
        worker_seeding: Worker random seeding strategy

    Returns:
        Configured DataLoader instance
    """

Transform Creation

Comprehensive transform pipelines with support for training and inference configurations.

def create_transform(
    input_size: Union[int, tuple],
    is_training: bool = False,
    use_prefetcher: bool = False,
    no_aug: bool = False,
    scale: tuple = (0.08, 1.0),
    ratio: tuple = (3./4., 4./3.),
    hflip: float = 0.5,
    vflip: float = 0.0,
    color_jitter: float = 0.4,
    auto_augment: Optional[str] = None,
    interpolation: str = 'bilinear',
    mean: tuple = IMAGENET_DEFAULT_MEAN,
    std: tuple = IMAGENET_DEFAULT_STD,
    re_prob: float = 0.0,
    re_mode: str = 'const',
    re_count: int = 1,
    re_num_splits: int = 0,
    crop_pct: Optional[float] = None,
    tf_preprocessing: bool = False,
    separate: bool = False,
    **kwargs
):
    """
    Create image transform pipeline.

    Args:
        input_size: Target input size for transforms
        is_training: Use training transforms with augmentation
        use_prefetcher: Skip normalization for CUDA prefetcher
        no_aug: Disable all augmentations
        scale: Random resized crop scale range
        ratio: Random resized crop aspect ratio range
        hflip: Horizontal flip probability
        vflip: Vertical flip probability
        color_jitter: Color jitter strength
        auto_augment: AutoAugment/RandAugment policy string
        interpolation: Resize interpolation method
        mean: Normalization mean values
        std: Normalization standard deviation values
        re_prob: Random erasing probability
        re_mode: Random erasing mode
        re_count: Random erasing count
        re_num_splits: Random erasing splits
        crop_pct: Center crop percentage
        tf_preprocessing: Use TensorFlow-style preprocessing
        separate: Return transforms as separate list

    Returns:
        Transform function or list of transforms
    """

Dataset Creation

Factory function for creating various dataset types including ImageNet, CIFAR, and custom datasets.

def create_dataset(
    name: str,
    root: str,
    split: str = 'validation',
    is_training: bool = False,
    class_map: Optional[dict] = None,
    load_bytes: bool = False,
    img_mode: str = 'RGB',
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    **kwargs
):
    """
    Create dataset instance.

    Args:
        name: Dataset name or path pattern
        root: Root directory containing dataset
        split: Dataset split ('train', 'validation', 'test')
        is_training: Training mode configuration
        class_map: Custom class mapping
        load_bytes: Load images as bytes instead of PIL
        img_mode: Image mode ('RGB', 'L', etc.)
        transform: Image transforms
        target_transform: Target/label transforms
        **kwargs: Dataset-specific arguments

    Returns:
        Dataset instance
    """

Dataset Classes

Core Dataset Classes

class ImageDataset(torch.utils.data.Dataset):
    """
    Standard image dataset for classification tasks.

    Args:
        root: Root directory path
        reader: Image reader instance
        class_to_idx: Class name to index mapping
        transform: Image transforms
        target_transform: Target transforms
    """
    
    def __init__(
        self,
        root: str,
        reader: Optional[Any] = None,
        class_to_idx: Optional[dict] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None
    ): ...
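
A minimal usage sketch for ImageDataset, assuming a class-per-subfolder layout (the path is illustrative):

from timm.data import ImageDataset, create_transform

# Folder layout: /path/to/train/<class_name>/<image files>
dataset = ImageDataset(
    root='/path/to/train',
    transform=create_transform(input_size=224, is_training=True)
)

img, target = dataset[0]  # transformed image tensor, integer class index
print(len(dataset), img.shape, target)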

class IterableImageDataset(torch.utils.data.IterableDataset):
    """
    Iterable image dataset for streaming large datasets.

    Args:
        root: Root directory or file pattern
        reader: Image reader instance
        split: Dataset split name
        is_training: Training mode
        batch_size: Batch size for iteration
        transform: Image transforms
    """
    
    def __init__(
        self,
        root: str,
        reader: Optional[Any] = None,
        split: str = 'train',
        is_training: bool = False,
        batch_size: Optional[int] = None,
        transform: Optional[Callable] = None
    ): ...
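
IterableImageDataset instances are usually obtained from create_dataset with a streaming reader prefix rather than constructed directly. A sketch, assuming TensorFlow Datasets is installed; the dataset name and path are illustrative:

from timm.data import create_dataset

# 'tfds/', 'wds/', and 'hfds/' name prefixes select streaming readers
dataset = create_dataset(
    'tfds/oxford_iiit_pet',
    root='/data/tfds',
    split='train',
    is_training=True,
    batch_size=32  # streaming readers use this for sharding/padding
)

# Consume by iteration (no random indexing)
for img, target in dataset:
    break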

class AugMixDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper for AugMix augmentation technique.

    Args:
        dataset: Base dataset whose transform is split into clean and augmented branches
        num_splits: Number of augmentation splits per sample (chain parameters such as
            width, depth, and alpha come from an 'augmix-...' auto_augment config string)
    """
    
    def __init__(
        self,
        dataset,
        num_splits: int = 2
    ): ...
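
A sketch of the split-augmentation setup used with an auxiliary JSD loss: wrap the training dataset and keep the loader's num_aug_splits in sync (paths illustrative):

from timm.data import create_dataset, create_loader, AugMixDataset

num_aug_splits = 2
train_dataset = create_dataset('imagefolder', root='/path/to/train', is_training=True)
train_dataset = AugMixDataset(train_dataset, num_splits=num_aug_splits)

train_loader = create_loader(
    train_dataset,
    input_size=224,
    batch_size=32,
    is_training=True,
    auto_augment='augmix-m5-w4-d2',  # magnitude/width/depth of the AugMix chains
    num_aug_splits=num_aug_splits    # must match the dataset wrapper
)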

Transform Classes

Basic Transforms

class ToTensor:
    """Convert PIL Image to tensor."""
    
    def __call__(self, pic): ...

class ToNumpy:
    """Convert PIL Image to a numpy array (HWC to CHW, uint8)."""
    
    def __call__(self, pil_img): ...

class RandomResizedCropAndInterpolation:
    """
    Random resized crop with configurable interpolation.

    Args:
        size: Target output size
        scale: Random crop scale range
        ratio: Random crop aspect ratio range
        interpolation: Interpolation method
    """
    
    def __init__(
        self,
        size: Union[int, tuple],
        scale: tuple = (0.08, 1.0),
        ratio: tuple = (3./4., 4./3.),
        interpolation: str = 'bilinear'
    ): ...
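
It drops in where torchvision's RandomResizedCrop would go; interpolation='random' picks an interpolation method per call. A minimal composition sketch:

from torchvision import transforms
from timm.data.transforms import RandomResizedCropAndInterpolation

train_tfms = transforms.Compose([
    RandomResizedCropAndInterpolation(224, interpolation='random'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])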

Augmentation Transforms

class RandAugment:
    """
    RandAugment augmentation.

    Args:
        ops: List of augmentation operations (AugmentOp instances with magnitude baked in)
        num_layers: Number of operations applied per image
        choice_weights: Optional sampling weights over ops
    """
    
    def __init__(
        self,
        ops: List[Any],
        num_layers: int = 2,
        choice_weights: Optional[List[float]] = None
    ): ...

class AutoAugment:
    """
    AutoAugment data augmentation.

    Args:
        policy: List of sub-policies, typically built from a policy name by the
            auto_augment_transform factory helper
    """
    
    def __init__(self, policy): ...

class TrivialAugmentWide:
    """TrivialAugment Wide augmentation strategy."""
    
    def __init__(self): ...
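
RandAugment and AutoAugment are normally built from config strings rather than instantiated directly, either through the auto_augment argument of create_transform or the factory helpers. A sketch; the hparams values are illustrative:

from timm.data.auto_augment import rand_augment_transform, auto_augment_transform

hparams = {'translate_const': 100, 'img_mean': (124, 116, 104)}

ra = rand_augment_transform('rand-m9-mstd0.5', hparams)  # RandAugment, magnitude 9, magnitude std 0.5
aa = auto_augment_transform('original', hparams)         # original AutoAugment policy

# Equivalent via the high-level factory:
# create_transform(input_size=224, is_training=True, auto_augment='rand-m9-mstd0.5')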

class Mixup:
    """
    Mixup data augmentation.

    Args:
        mixup_alpha: Mixup interpolation coefficient
        cutmix_alpha: CutMix interpolation coefficient
        cutmix_minmax: CutMix min/max box size ratios
        prob: Probability of applying mixup/cutmix
        switch_prob: Probability of switching between mixup and cutmix
        mode: Mixup mode ('batch', 'pair', 'elem')
        correct_lam: Apply lambda correction
        label_smoothing: Label smoothing value
        num_classes: Number of classes
    """
    
    def __init__(
        self,
        mixup_alpha: float = 1.0,
        cutmix_alpha: float = 0.0,
        cutmix_minmax: Optional[tuple] = None,
        prob: float = 1.0,
        switch_prob: float = 0.5,
        mode: str = 'batch',
        correct_lam: bool = True,
        label_smoothing: float = 0.1,
        num_classes: int = 1000
    ): ...
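
Mixup operates on whole batches after loading and returns soft label distributions, so pair it with a loss that accepts soft targets such as timm's SoftTargetCrossEntropy. A small sketch:

import torch
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy

mixup = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=10)

x = torch.randn(8, 3, 224, 224)  # batch size must be even in 'batch' mode
y = torch.randint(0, 10, (8,))
x, y = mixup(x, y)               # y is now soft, shape (8, 10)

logits = torch.randn(8, 10)      # stand-in for model output
loss = SoftTargetCrossEntropy()(logits, y)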

Data Configuration

Constants and Defaults

# ImageNet normalization constants
IMAGENET_DEFAULT_MEAN: tuple = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD: tuple = (0.229, 0.224, 0.225)

# ImageNet Inception normalization
IMAGENET_INCEPTION_MEAN: tuple = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD: tuple = (0.5, 0.5, 0.5)

# OpenAI CLIP normalization
OPENAI_CLIP_MEAN: tuple = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD: tuple = (0.26862954, 0.26130258, 0.27577711)
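
These tuples plug directly into the mean/std arguments of create_transform and create_loader; models trained with Inception-style preprocessing expect the 0.5 constants. For example:

from timm.data import create_transform, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

transform = create_transform(
    input_size=299,
    mean=IMAGENET_INCEPTION_MEAN,
    std=IMAGENET_INCEPTION_STD
)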

Configuration Functions

def resolve_data_config(
    args=None,
    pretrained_cfg: Optional[dict] = None,
    model: Optional[torch.nn.Module] = None,
    use_test_size: bool = False,
    verbose: bool = False
) -> dict:
    """
    Resolve data configuration from model, args, or defaults.

    Args:
        args: Argument namespace with data config
        pretrained_cfg: Pretrained model configuration
        model: Model instance to extract config from
        use_test_size: Use test/inference input size
        verbose: Print resolved configuration

    Returns:
        Dictionary with resolved data configuration
    """

def resolve_model_data_config(
    model: torch.nn.Module,
    args=None,
    pretrained_cfg: Optional[dict] = None,
    use_test_size: bool = False,
    verbose: bool = False
) -> dict:
    """
    Resolve data configuration specifically from model.

    Args:
        model: Model instance
        args: Additional arguments
        pretrained_cfg: Pretrained configuration override
        use_test_size: Use inference input size
        verbose: Print configuration details

    Returns:
        Model-specific data configuration
    """

Data Readers

Image Readers

def create_reader(
    name: str,
    root: str,
    split: str = 'train',
    **kwargs
):
    """
    Create image reader for different data formats.

    Args:
        name: Reader type ('', 'hfds', 'tfds', 'wds')
        root: Data root path
        split: Dataset split
        **kwargs: Reader-specific arguments

    Returns:
        Configured reader instance
    """

def get_img_extensions() -> set:
    """
    Get supported image file extensions.

    Returns:
        Set of supported extensions
    """

def is_img_extension(filename: str) -> bool:
    """
    Check if filename has supported image extension.

    Args:
        filename: File name to check

    Returns:
        True if supported image format
    """

Usage Examples

Basic Data Pipeline

import timm
from timm.data import create_loader, create_transform, create_dataset

# Create transforms for training and validation
train_transform = create_transform(
    input_size=224,
    is_training=True,
    hflip=0.5,
    color_jitter=0.4,
    auto_augment='original'
)

val_transform = create_transform(
    input_size=224,
    is_training=False
)

# Create datasets
train_dataset = create_dataset(
    'imagefolder',
    root='/path/to/train',
    transform=train_transform
)

val_dataset = create_dataset(
    'imagefolder', 
    root='/path/to/val',
    transform=val_transform
)

# Create data loaders
train_loader = create_loader(
    train_dataset,
    input_size=224,
    batch_size=32,
    is_training=True,
    num_workers=4
)

val_loader = create_loader(
    val_dataset,
    input_size=224,
    batch_size=64,
    is_training=False,
    num_workers=4
)

Note: create_loader builds its own transform pipeline from its arguments and assigns it to dataset.transform, so when going through create_loader you normally pass augmentation options to the loader itself; the standalone transforms above matter when you use the datasets directly.

Advanced Augmentation

from timm.data import Mixup, create_loader

# Create mixup augmentation
mixup = Mixup(
    mixup_alpha=0.8,
    cutmix_alpha=1.0,
    prob=1.0,
    switch_prob=0.5,
    mode='batch',
    label_smoothing=0.1,
    num_classes=1000
)

# Create loader with RandAugment and random erasing
# (mixup/cutmix are not loader arguments; the Mixup object above is applied per batch)
train_loader = create_loader(
    train_dataset,  # dataset from the previous example
    input_size=224,
    batch_size=32,
    is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',
    re_prob=0.25
)

# Apply mixup in training loop
for batch_idx, (input, target) in enumerate(train_loader):
    if mixup is not None:
        input, target = mixup(input, target)
    # ... training code

Types

from typing import Optional, Union, List, Dict, Callable, Any, Tuple
import torch

# Transform and dataset types
TransformType = Callable[[Any], torch.Tensor]
DatasetType = torch.utils.data.Dataset
LoaderType = torch.utils.data.DataLoader

# Data configuration
DataConfig = Dict[str, Any]
AugmentConfig = Dict[str, Any]

# Common type aliases
ImageSize = Union[int, Tuple[int, int]]
NormStats = Tuple[float, float, float]
