CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-mmengine

Engine of OpenMMLab projects for training deep learning models based on PyTorch with large-scale training frameworks, configuration management, and monitoring capabilities

Pending
Overview
Eval results
Files

docs/dataset.md

Dataset and Data Processing

Dataset abstraction layer with support for various dataset types, data transformations, sampling strategies, and data loading utilities optimized for distributed training. The system provides flexible data processing pipelines for machine learning workflows.

Capabilities

Base Dataset Class

Foundation class for all datasets with standardized interface and lazy loading support.

class BaseDataset:
    def __init__(self, ann_file: str = '', metainfo: dict = None, data_root: str = '', data_prefix: dict = None, filter_cfg: dict = None, indices: int = None, serialize_data: bool = True, pipeline: list = [], test_mode: bool = False, lazy_init: bool = False, max_refetch: int = 1000):
        """
        Base dataset class.
        
        Parameters:
        - ann_file: Annotation file path
        - metainfo: Dataset meta information
        - data_root: Data root directory
        - data_prefix: Prefix for different data types
        - filter_cfg: Config for filtering data
        - indices: Load only a subset of the annotations — an int keeps the first N samples, a sequence keeps the samples at those indices (useful for quick experiments on a smaller dataset)
        - serialize_data: Whether to serialize data for faster loading
        - pipeline: Data processing pipeline
        - test_mode: Whether in test mode
        - lazy_init: Whether to initialize lazily
        - max_refetch: Maximum refetch attempts for corrupted data
        """

    def __len__(self) -> int:
        """
        Get dataset size.
        
        Returns:
        Dataset length
        """

    def __getitem__(self, idx: int):
        """
        Get data sample by index.
        
        Parameters:
        - idx: Sample index
        
        Returns:
        Data sample
        """

    def get_data_info(self, idx: int) -> dict:
        """
        Get data information by index.
        
        Parameters:
        - idx: Sample index
        
        Returns:
        Data information dictionary
        """

    def prepare_data(self, idx: int) -> dict:
        """
        Prepare data for processing pipeline.
        
        Parameters:
        - idx: Sample index
        
        Returns:
        Prepared data dictionary
        """

    def load_data_list(self) -> list:
        """
        Load annotation file and return data list.
        
        Returns:
        List of data information
        """

    def filter_data(self) -> list:
        """
        Filter data according to filter_cfg.
        
        Returns:
        Filtered data list
        """

    def get_subset_(self, indices: list):
        """
        Get subset of dataset.
        
        Parameters:
        - indices: Indices for subset
        
        Returns:
        Dataset subset
        """

    @property
    def metainfo(self) -> dict:
        """Get dataset meta information."""

    def full_init(self):
        """Fully initialize dataset."""

Data Transforms

Transform composition system for data preprocessing and augmentation.

class Compose:
    def __init__(self, transforms: list):
        """
        Compose multiple transforms.
        
        Parameters:
        - transforms: List of transform configurations or instances
        """

    def __call__(self, data: dict) -> dict:
        """
        Apply transforms to data.
        
        Parameters:
        - data: Input data dictionary
        
        Returns:
        Transformed data
        """

    def __repr__(self) -> str:
        """String representation of transforms."""

Dataset Wrappers

Wrapper classes for modifying dataset behavior.

class ClassBalancedDataset:
    def __init__(self, dataset, oversample_thr: float = 1e-3, random_state: int = None):
        """
        Dataset wrapper for class balancing through oversampling.
        
        Parameters:
        - dataset: Original dataset
        - oversample_thr: Frequency threshold for oversampling — samples whose category frequency falls below this value are repeated more often
        - random_state: Random state for reproducibility
        """

    def __len__(self) -> int:
        """Get balanced dataset length."""

    def __getitem__(self, idx: int):
        """Get balanced sample by index."""

class ConcatDataset:
    def __init__(self, datasets: list):
        """
        Concatenate multiple datasets.
        
        Parameters:
        - datasets: List of datasets to concatenate
        """

    def __len__(self) -> int:
        """Get total length of concatenated datasets."""

    def __getitem__(self, idx: int):
        """Get sample from appropriate dataset."""

    def get_dataset_idx_and_sample_idx(self, idx: int) -> tuple:
        """
        Get dataset index and sample index.
        
        Parameters:
        - idx: Global index
        
        Returns:
        Tuple of (dataset_idx, sample_idx)
        """

class RepeatDataset:
    def __init__(self, dataset, times: int):
        """
        Repeat dataset multiple times.
        
        Parameters:
        - dataset: Original dataset
        - times: Number of repetitions
        """

    def __len__(self) -> int:
        """Get repeated dataset length."""

    def __getitem__(self, idx: int):
        """Get sample from repeated dataset."""

Data Samplers

Sampling strategies for data loading in different training scenarios.

class DefaultSampler:
    def __init__(self, dataset, shuffle: bool = True, seed: int = None, round_up: bool = True):
        """
        Default data sampler.
        
        Parameters:
        - dataset: Dataset to sample from
        - shuffle: Whether to shuffle data
        - seed: Random seed
        - round_up: Whether to add extra samples so the total is evenly divisible across distributed processes (each rank receives the same number of samples)
        """

    def __iter__(self):
        """Iterator over sample indices."""

    def __len__(self) -> int:
        """Get number of samples."""

class InfiniteSampler:
    def __init__(self, dataset, shuffle: bool = True, seed: int = None):
        """
        Infinite data sampler for continuous sampling.
        
        Parameters:
        - dataset: Dataset to sample from
        - shuffle: Whether to shuffle data
        - seed: Random seed
        """

    def __iter__(self):
        """Infinite iterator over sample indices."""

    def __len__(self) -> int:
        """Get dataset length."""

    def set_epoch(self, epoch: int):
        """
        Set epoch for sampling.
        
        Parameters:
        - epoch: Current epoch
        """

Data Loading Utilities

Utility functions for data loading and processing.

def force_full_init(dataset):
    """
    Force full initialization of dataset.
    
    Parameters:
    - dataset: Dataset to initialize
    """

def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """
    Worker initialization function for DataLoader.
    
    Parameters:
    - worker_id: Worker ID
    - num_workers: Total number of workers
    - rank: Process rank
    - seed: Random seed
    """

def pseudo_collate(batch: list) -> list:
    """
    Pseudo collate function that doesn't actually collate.
    
    Parameters:
    - batch: List of samples
    
    Returns:
    Original batch list
    """

def default_collate(batch: list):
    """
    Default collate function for batching data.
    
    Parameters:
    - batch: List of samples
    
    Returns:
    Collated batch
    """

Collate Functions

Registry of available collate functions for different data types.

COLLATE_FUNCTIONS: dict  # Dictionary mapping names to collate functions

Usage Examples

Basic Dataset Implementation

from mmengine.dataset import BaseDataset
import json
import os

class CustomDataset(BaseDataset):
    def __init__(self, ann_file, data_root, **kwargs):
        self.data_root = data_root
        super().__init__(ann_file=ann_file, **kwargs)
    
    def load_data_list(self):
        """Load annotation file."""
        with open(self.ann_file, 'r') as f:
            data_list = json.load(f)
        
        # Process annotations
        for data_info in data_list:
            data_info['img_path'] = os.path.join(
                self.data_root, data_info['filename']
            )
        
        return data_list
    
    def prepare_data(self, idx):
        """Prepare data for pipeline."""
        data_info = self.get_data_info(idx)
        return {
            'img_path': data_info['img_path'],
            'gt_label': data_info['label'],
            'sample_idx': idx
        }

# Usage
dataset = CustomDataset(
    ann_file='annotations.json',
    data_root='data/',
    pipeline=[
        dict(type='LoadImageFromFile'),
        dict(type='Resize', scale=(224, 224)),
        dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        dict(type='PackInputs')
    ]
)

Data Pipeline Configuration

from mmengine.dataset import Compose

# Define data pipeline
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    dict(type='PackInputs')
]

val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=256),
    dict(type='CenterCrop', size=224),
    dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    dict(type='PackInputs')
]

# Create transform compositions
train_transforms = Compose(train_pipeline)
val_transforms = Compose(val_pipeline)

# Apply to datasets
train_dataset = CustomDataset(ann_file='train.json', pipeline=train_pipeline)
val_dataset = CustomDataset(ann_file='val.json', pipeline=val_pipeline)

Dataset Wrappers Usage

from mmengine.dataset import ClassBalancedDataset, ConcatDataset, RepeatDataset

# Class balancing for imbalanced datasets
balanced_dataset = ClassBalancedDataset(
    dataset=train_dataset,
    oversample_thr=1e-3
)

# Concatenate multiple datasets
combined_dataset = ConcatDataset([
    dataset1,
    dataset2,
    dataset3
])

# Repeat dataset for more training data
repeated_dataset = RepeatDataset(
    dataset=small_dataset,
    times=10
)

Custom Sampler Implementation

from mmengine.dataset import DefaultSampler, default_collate
import torch.utils.data as data

# Create sampler
sampler = DefaultSampler(
    dataset=train_dataset,
    shuffle=True,
    seed=42,
    round_up=True
)

# Use with DataLoader
dataloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    sampler=sampler,
    collate_fn=default_collate,
    worker_init_fn=lambda worker_id: worker_init_fn(
        worker_id, num_workers=4, rank=0, seed=42
    )
)

Distributed Data Loading

from torch.utils.data.distributed import DistributedSampler
from mmengine.dataset import force_full_init
from mmengine.dist import get_rank

# Force full dataset initialization for distributed training
force_full_init(dataset)

# Create distributed sampler
sampler = DistributedSampler(
    dataset=dataset,
    shuffle=True,
    seed=42
)

# DataLoader for distributed training
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
    worker_init_fn=lambda worker_id: worker_init_fn(
        worker_id, num_workers=4, rank=get_rank(), seed=42
    )
)

Infinite Sampling for Continuous Training

from mmengine.dataset import InfiniteSampler

# Create infinite sampler
infinite_sampler = InfiniteSampler(
    dataset=dataset,
    shuffle=True,
    seed=42
)

# Use for continuous training
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=infinite_sampler
)

# Training loop with infinite data
for epoch in range(num_epochs):
    infinite_sampler.set_epoch(epoch)
    for i, batch in enumerate(dataloader):
        if i >= steps_per_epoch:
            break
        # Training step
        train_step(batch)

Custom Collate Function

import torch

def custom_collate(batch):
    """Custom collate function for special data types."""
    images = []
    labels = []
    metadata = []
    
    for sample in batch:
        images.append(sample['image'])
        labels.append(sample['label'])
        metadata.append(sample['metadata'])
    
    return {
        'images': torch.stack(images),
        'labels': torch.tensor(labels),
        'metadata': metadata
    }

# Register custom collate function
COLLATE_FUNCTIONS['custom_collate'] = custom_collate

# Use in dataset configuration
dataset_cfg = dict(
    type='CustomDataset',
    collate_fn='custom_collate',
    # ... other configs
)

Dataset Filtering

class FilteredDataset(BaseDataset):
    def __init__(self, min_size=32, **kwargs):
        self.min_size = min_size
        super().__init__(**kwargs)
    
    def filter_data(self):
        """Filter out samples smaller than min_size."""
        valid_data_infos = []
        for data_info in self.data_list:
            if data_info.get('width', 0) >= self.min_size and \
               data_info.get('height', 0) >= self.min_size:
                valid_data_infos.append(data_info)
        return valid_data_infos

# Usage
filtered_dataset = FilteredDataset(
    ann_file='annotations.json',
    min_size=64,
    filter_cfg=dict(filter_empty_gt=True)
)

Install with Tessl CLI

npx tessl i tessl/pypi-mmengine

docs

configuration.md

dataset.md

distributed.md

fileio.md

index.md

logging.md

models.md

optimization.md

registry.md

training.md

visualization.md

tile.json