MMEngine: the foundational training engine of the OpenMMLab projects for training deep learning models with PyTorch, providing large-scale training frameworks, configuration management, and monitoring capabilities
—
Dataset abstraction layer with support for various dataset types, data transformations, sampling strategies, and data loading utilities optimized for distributed training. The system provides flexible data processing pipelines for machine learning workflows.
Foundation class for all datasets with standardized interface and lazy loading support.
from typing import Callable, Optional, Sequence, Union


class BaseDataset:
    """Base dataset class.

    Parameters:
    - ann_file: Annotation file path
    - metainfo: Dataset meta information
    - data_root: Data root directory
    - data_prefix: Prefix for different data types
    - filter_cfg: Config for filtering data
    - indices: Dataset indices to use (an int prefix or an explicit sequence)
    - serialize_data: Whether to serialize data for faster loading
    - pipeline: Data processing pipeline
    - test_mode: Whether in test mode
    - lazy_init: Whether to initialize lazily
    - max_refetch: Maximum refetch attempts for corrupted data
    """

    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Optional[dict] = None,
                 filter_cfg: Optional[dict] = None,
                 indices: Union[int, Sequence[int], None] = None,
                 serialize_data: bool = True,
                 # NB: mutable default kept to mirror the upstream mmengine
                 # signature; callers should still pass their own list.
                 pipeline: Sequence[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000):
        """Initialize the dataset (see class docstring for parameters)."""

    def __len__(self) -> int:
        """Get dataset size.

        Returns:
        Dataset length
        """

    def __getitem__(self, idx: int):
        """Get data sample by index.

        Parameters:
        - idx: Sample index

        Returns:
        Data sample
        """

    def get_data_info(self, idx: int) -> dict:
        """Get data information by index.

        Parameters:
        - idx: Sample index

        Returns:
        Data information dictionary
        """

    def prepare_data(self, idx: int) -> dict:
        """Prepare data for the processing pipeline.

        Parameters:
        - idx: Sample index

        Returns:
        Prepared data dictionary
        """

    def load_data_list(self) -> list:
        """Load annotation file and return data list.

        Returns:
        List of data information
        """

    def filter_data(self) -> list:
        """Filter data according to filter_cfg.

        Returns:
        Filtered data list
        """

    def get_subset_(self, indices: Union[int, Sequence[int]]):
        """Restrict the dataset to a subset, in place.

        Parameters:
        - indices: Indices for subset

        NOTE(review): the trailing underscore conventionally marks an
        in-place operation returning ``None`` (cf. ``get_subset`` in
        mmengine); the original doc claimed it returns the subset —
        confirm before relying on a return value.
        """

    @property
    def metainfo(self) -> dict:
        """Get dataset meta information."""

    def full_init(self):
        """Fully initialize dataset."""


# Transform composition system for data preprocessing and augmentation.
class Compose:
    """Compose multiple transforms into a single callable.

    Parameters:
    - transforms: List of transform configurations or instances
    """

    def __init__(self, transforms: list):
        """Store/build the transform sequence."""

    def __call__(self, data: dict) -> dict:
        """Apply the transforms to data in sequence.

        Parameters:
        - data: Input data dictionary

        Returns:
        Transformed data
        """

    def __repr__(self) -> str:
        """String representation of the composed transforms."""


# Wrapper classes for modifying dataset behavior.
from typing import Optional


class ClassBalancedDataset:
    """Dataset wrapper for class balancing through oversampling.

    Parameters:
    - dataset: Original dataset
    - oversample_thr: Threshold for oversampling
    - random_state: Random state for reproducibility
    """

    def __init__(self, dataset, oversample_thr: float = 1e-3,
                 random_state: Optional[int] = None):
        """Initialize the wrapper (see class docstring)."""

    def __len__(self) -> int:
        """Get balanced dataset length."""

    def __getitem__(self, idx: int):
        """Get balanced sample by index."""
class ConcatDataset:
    """Concatenate multiple datasets into one logical dataset.

    Parameters:
    - datasets: List of datasets to concatenate
    """

    def __init__(self, datasets: list):
        """Store the datasets to concatenate."""

    def __len__(self) -> int:
        """Get total length of concatenated datasets."""

    def __getitem__(self, idx: int):
        """Get sample from the appropriate underlying dataset."""

    def get_dataset_idx_and_sample_idx(self, idx: int) -> tuple:
        """Map a global index to dataset and sample indices.

        Parameters:
        - idx: Global index

        Returns:
        Tuple of (dataset_idx, sample_idx)
        """
class RepeatDataset:
    """Repeat a dataset multiple times.

    Parameters:
    - dataset: Original dataset
    - times: Number of repetitions
    """

    def __init__(self, dataset, times: int):
        """Store the dataset and repetition count."""

    def __len__(self) -> int:
        """Get repeated dataset length."""

    def __getitem__(self, idx: int):
        """Get sample from repeated dataset."""


# Sampling strategies for data loading in different training scenarios.
from typing import Optional


class DefaultSampler:
    """Default data sampler.

    Parameters:
    - dataset: Dataset to sample from
    - shuffle: Whether to shuffle data
    - seed: Random seed
    - round_up: Whether to round up dataset size
    """

    def __init__(self, dataset, shuffle: bool = True,
                 seed: Optional[int] = None, round_up: bool = True):
        """Initialize the sampler (see class docstring)."""

    def __iter__(self):
        """Iterator over sample indices."""

    def __len__(self) -> int:
        """Get number of samples."""
from typing import Optional


class InfiniteSampler:
    """Infinite data sampler for continuous sampling.

    Parameters:
    - dataset: Dataset to sample from
    - shuffle: Whether to shuffle data
    - seed: Random seed
    """

    def __init__(self, dataset, shuffle: bool = True,
                 seed: Optional[int] = None):
        """Initialize the sampler (see class docstring)."""

    def __iter__(self):
        """Infinite iterator over sample indices."""

    def __len__(self) -> int:
        """Get dataset length."""

    def set_epoch(self, epoch: int):
        """Set epoch for sampling.

        Parameters:
        - epoch: Current epoch
        """


# Utility functions for data loading and processing.
def force_full_init(dataset):
    """Force full initialization of a dataset.

    Parameters:
    - dataset: Dataset to initialize

    NOTE(review): in mmengine, ``force_full_init`` is defined as a method
    decorator rather than a function applied to a dataset instance —
    confirm the intended call style against the installed version.
    """
def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """Worker initialization function for DataLoader.

    Parameters:
    - worker_id: Worker ID
    - num_workers: Total number of workers
    - rank: Process rank
    - seed: Random seed
    """
def pseudo_collate(batch: list) -> list:
    """Pseudo collate function that doesn't actually collate.

    Parameters:
    - batch: List of samples

    Returns:
    Original batch list
    """
def default_collate(batch: list):
    """Default collate function for batching data.

    Parameters:
    - batch: List of samples

    Returns:
    Collated batch
    """


# Registry of available collate functions for different data types.
COLLATE_FUNCTIONS: dict  # Dictionary mapping names to collate functions


# --- Example: defining a custom dataset ---
from mmengine.dataset import BaseDataset
import json
import os
class CustomDataset(BaseDataset):
    """Example dataset reading a flat JSON annotation list.

    Parameters:
    - ann_file: Path to a JSON file containing a list of annotation dicts,
      each with at least a 'filename' and a 'label' key — TODO confirm schema.
    - data_root: Directory prepended to each annotation's 'filename'.
    """

    def __init__(self, ann_file, data_root, **kwargs):
        # Must be set before super().__init__, which may call
        # load_data_list() during (non-lazy) initialization.
        self.data_root = data_root
        super().__init__(ann_file=ann_file, **kwargs)

    def load_data_list(self):
        """Load the annotation file and attach absolute image paths."""
        with open(self.ann_file, 'r') as f:
            data_list = json.load(f)
        # Process annotations: resolve each filename against data_root.
        for data_info in data_list:
            data_info['img_path'] = os.path.join(
                self.data_root, data_info['filename'])
        return data_list

    def prepare_data(self, idx):
        """Build the pipeline input for ``idx`` and run the pipeline."""
        data_info = self.get_data_info(idx)
        data = {
            'img_path': data_info['img_path'],
            'gt_label': data_info['label'],
            'sample_idx': idx,
        }
        # Run the configured transform pipeline; returning the raw dict
        # (as the original did) would silently bypass every transform
        # configured in `pipeline`.
        return self.pipeline(data)
# Usage
dataset = CustomDataset(
    ann_file='annotations.json',
    data_root='data/',
    pipeline=[
        dict(type='LoadImageFromFile'),
        dict(type='Resize', scale=(224, 224)),
        dict(type='Normalize', mean=[0.485, 0.456, 0.406],
             std=[0.229, 0.224, 0.225]),
        dict(type='PackInputs'),
    ],
)

# --- Example: composing transform pipelines ---
from mmengine.dataset import Compose
# Define the train/validation data pipelines.
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'RandomResizedCrop', 'scale': 224},
    {'type': 'RandomFlip', 'prob': 0.5},
    {'type': 'ColorJitter', 'brightness': 0.4, 'contrast': 0.4,
     'saturation': 0.4},
    {'type': 'Normalize', 'mean': [0.485, 0.456, 0.406],
     'std': [0.229, 0.224, 0.225]},
    {'type': 'PackInputs'},
]
val_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'Resize', 'scale': 256},
    {'type': 'CenterCrop', 'size': 224},
    {'type': 'Normalize', 'mean': [0.485, 0.456, 0.406],
     'std': [0.229, 0.224, 0.225]},
    {'type': 'PackInputs'},
]
# Create transform compositions
train_transforms = Compose(train_pipeline)
val_transforms = Compose(val_pipeline)

# Apply the pipelines when constructing datasets.
# NB: CustomDataset (defined above) requires `data_root`; the original
# example omitted it, which would raise TypeError.
train_dataset = CustomDataset(ann_file='train.json', data_root='data/',
                              pipeline=train_pipeline)
val_dataset = CustomDataset(ann_file='val.json', data_root='data/',
                            pipeline=val_pipeline)

# --- Example: dataset wrappers ---
from mmengine.dataset import ClassBalancedDataset, ConcatDataset, RepeatDataset
# Class balancing for imbalanced datasets
balanced_dataset = ClassBalancedDataset(
    dataset=train_dataset,
    oversample_thr=1e-3,
)

# Concatenate multiple datasets (dataset1..dataset3 are placeholders
# assumed to be defined by the surrounding script).
combined_dataset = ConcatDataset([
    dataset1,
    dataset2,
    dataset3,
])

# Repeat a small dataset to get more training iterations per epoch.
repeated_dataset = RepeatDataset(
    dataset=small_dataset,
    times=10,
)

# --- Example: samplers and DataLoader integration ---
from mmengine.dataset import DefaultSampler
import torch.utils.data as data
# The original snippet used these without importing them.
from mmengine.dataset import default_collate, worker_init_fn

# Create sampler
sampler = DefaultSampler(
    dataset=train_dataset,
    shuffle=True,
    seed=42,
    round_up=True,
)

# Use with DataLoader
dataloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    sampler=sampler,
    collate_fn=default_collate,
    # DataLoader only passes worker_id; bind the remaining arguments here.
    worker_init_fn=lambda worker_id: worker_init_fn(
        worker_id, num_workers=4, rank=0, seed=42),
)

# --- Example: distributed training ---
from torch.utils.data.distributed import DistributedSampler
from mmengine.dataset import force_full_init, worker_init_fn
# The original snippet called get_rank() without importing it.
from mmengine.dist import get_rank

# Force full dataset initialization for distributed training.
# NOTE(review): in mmengine `force_full_init` is a method decorator; to
# eagerly initialize an instance, `dataset.full_init()` may be what is
# intended — confirm against the installed mmengine version.
force_full_init(dataset)

# Create distributed sampler
sampler = DistributedSampler(
    dataset=dataset,
    shuffle=True,
    seed=42,
)

# DataLoader for distributed training
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
    worker_init_fn=lambda worker_id: worker_init_fn(
        worker_id, num_workers=4, rank=get_rank(), seed=42),
)

# --- Example: infinite sampling ---
from mmengine.dataset import InfiniteSampler
# Create infinite sampler
infinite_sampler = InfiniteSampler(
    dataset=dataset,
    shuffle=True,
    seed=42,
)

# Use for continuous training
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=infinite_sampler,
)

# Training loop with infinite data (num_epochs, steps_per_epoch and
# train_step are assumed to be defined by the surrounding script).
for epoch in range(num_epochs):
    infinite_sampler.set_epoch(epoch)
    for i, batch in enumerate(dataloader):
        # The infinite sampler never exhausts, so cap each epoch manually.
        if i >= steps_per_epoch:
            break
        # Training step
        train_step(batch)


# --- Example: custom collate function ---
import torch  # the original used torch.stack/torch.tensor without importing


def custom_collate(batch):
    """Custom collate function for special data types.

    Parameters:
    - batch: list of sample dicts with 'image' (Tensor), 'label' and
      'metadata' keys — TODO confirm sample schema against the dataset.

    Returns:
    dict with stacked image tensor, label tensor and raw metadata list.
    """
    images = [sample['image'] for sample in batch]
    labels = [sample['label'] for sample in batch]
    metadata = [sample['metadata'] for sample in batch]
    return {
        'images': torch.stack(images),
        'labels': torch.tensor(labels),
        'metadata': metadata,
    }
# Register custom collate function.
# NOTE(review): in mmengine COLLATE_FUNCTIONS is a Registry; if item
# assignment is unsupported, register with
# COLLATE_FUNCTIONS.register_module(name='custom_collate', module=custom_collate).
COLLATE_FUNCTIONS['custom_collate'] = custom_collate

# Use in dataset configuration
dataset_cfg = dict(
    type='CustomDataset',
    collate_fn='custom_collate',
    # ... other configs
)


class FilteredDataset(BaseDataset):
    """Dataset that drops samples smaller than ``min_size`` pixels.

    Parameters:
    - min_size: Minimum width and height a sample must have to be kept.
    """

    def __init__(self, min_size=32, **kwargs):
        # Must be set before super().__init__, which may trigger
        # filter_data() during full initialization.
        self.min_size = min_size
        super().__init__(**kwargs)

    def filter_data(self):
        """Filter out samples smaller than min_size.

        Returns:
        The filtered data list; samples missing width/height are dropped
        (the .get default of 0 never passes the threshold).
        """
        return [
            data_info for data_info in self.data_list
            if data_info.get('width', 0) >= self.min_size
            and data_info.get('height', 0) >= self.min_size
        ]
# Usage
filtered_dataset = FilteredDataset(
    ann_file='annotations.json',
    min_size=64,
    filter_cfg=dict(filter_empty_gt=True),
)

# Install with Tessl CLI:
#   npx tessl i tessl/pypi-mmengine