# Dataset and Data Processing

Dataset abstraction layer with support for various dataset types, data transformations, sampling strategies, and data loading utilities designed for distributed training. Together these components form flexible data-processing pipelines for machine-learning workflows.

## Capabilities

### Base Dataset Class

Foundation class for all datasets, providing a standardized interface and optional lazy initialization (`lazy_init=True` defers loading annotations until `full_init()` runs).

```python { .api }
class BaseDataset:
    """
    Base class for all datasets.

    Subclasses typically override ``load_data_list`` (and optionally
    ``filter_data``); indexing returns samples processed by ``pipeline``.
    """

    def __init__(
        self,
        ann_file: str = '',
        metainfo: dict = None,
        data_root: str = '',
        data_prefix: dict = None,
        filter_cfg: dict = None,
        indices: 'int | list[int] | None' = None,
        serialize_data: bool = True,
        pipeline: list = [],  # documented default; the stub never mutates it
        test_mode: bool = False,
        lazy_init: bool = False,
        max_refetch: int = 1000,
    ):
        """
        Base dataset class.

        Parameters:
        - ann_file: Annotation file path (a relative path is resolved
          against ``data_root``)
        - metainfo: Dataset meta information
        - data_root: Data root directory
        - data_prefix: Prefix for different data types
        - filter_cfg: Config for filtering data, interpreted by
          ``filter_data`` (the base implementation keeps all data)
        - indices: Subset of samples to use. An int keeps the first
          ``indices`` samples, a list of ints keeps exactly those samples,
          and None keeps all samples.
        - serialize_data: Whether to store the data list as serialized bytes
          so DataLoader worker processes can share it instead of each
          holding its own copy
        - pipeline: Data processing pipeline (transform configs or callables)
        - test_mode: Whether in test mode
        - lazy_init: If True, do not load annotations during construction;
          they are loaded by ``full_init()``, which runs automatically on
          the first access that needs them
        - max_refetch: Maximum number of attempts to fetch a replacement
          sample when the pipeline returns None for the requested index
        """

    def __len__(self) -> int:
        """
        Get dataset size.

        Returns:
        Number of samples after filtering and subsetting
        """

    def __getitem__(self, idx: int):
        """
        Get the pipeline-processed sample at ``idx``.

        In training mode, if the pipeline returns None another sample is
        drawn, up to ``max_refetch`` times. In test mode, a None result is
        an error instead.

        Parameters:
        - idx: Sample index

        Returns:
        Data sample produced by the pipeline
        """

    def get_data_info(self, idx: int) -> dict:
        """
        Get data information by index.

        Parameters:
        - idx: Sample index

        Returns:
        Copy of the annotation dictionary for ``idx``, including its
        ``sample_idx``
        """

    def prepare_data(self, idx: int) -> dict:
        """
        Get the data info for ``idx`` and run it through ``self.pipeline``.

        Subclasses overriding this must still apply the pipeline, otherwise
        the configured transforms are skipped.

        Parameters:
        - idx: Sample index

        Returns:
        Pipeline output for the sample (None if a transform rejected it)
        """

    def load_data_list(self) -> list:
        """
        Load annotation file and return data list.

        Returns:
        List of data information dictionaries, one per sample
        """

    def filter_data(self) -> list:
        """
        Filter data according to filter_cfg.

        Returns:
        Filtered data list
        """

    def get_subset_(self, indices: list):
        """
        Convert this dataset into a subset of itself, in place.

        The trailing underscore marks the in-place variant: the dataset
        object is modified and nothing is returned.

        Parameters:
        - indices: Indices of the samples to keep

        Returns:
        None
        """

    @property
    def metainfo(self) -> dict:
        """Get a copy of the dataset meta information."""

    def full_init(self):
        """
        Fully initialize the dataset.

        Loads the annotations, filters them, applies ``indices`` and
        optionally serializes the result. Does nothing if the dataset is
        already fully initialized.
        """
```

### Data Transforms

Transform composition system for data preprocessing and augmentation.

```python { .api }
class Compose:
    """Apply a sequence of transforms to a data dictionary, in order."""

    def __init__(self, transforms: list):
        """
        Compose multiple transforms.

        Parameters:
        - transforms: List of transform configurations (dicts resolved by
          their ``type`` name) or callable transform instances
        """

    def __call__(self, data: dict) -> 'dict | None':
        """
        Apply transforms to data sequentially.

        Parameters:
        - data: Input data dictionary

        Returns:
        Transformed data, or None if any transform returned None (the
        remaining transforms are then skipped)
        """

    def __repr__(self) -> str:
        """String representation of transforms."""
```

### Dataset Wrappers

Wrapper classes that change how an existing dataset is sampled or indexed: class-balanced oversampling, concatenation, and repetition.

```python { .api }
class ClassBalancedDataset:
    """Oversample samples of rare categories to balance class frequencies."""

    def __init__(self, dataset, oversample_thr: float = 1e-3, random_state: int = None):
        """
        Dataset wrapper for class balancing through oversampling.

        For each category c with frequency f(c) (fraction of samples that
        contain it), the repeat factor is r(c) = max(1, sqrt(t / f(c))) with
        t = ``oversample_thr``. Each sample is repeated according to the
        largest r(c) among its categories.

        Parameters:
        - dataset: Original dataset. It must provide per-sample category ids
          (``get_cat_ids(idx)``) so category frequencies can be computed.
        - oversample_thr: Frequency threshold below which categories are
          oversampled
        - random_state: Random state for reproducibility. NOTE(review):
          confirm this argument exists in the installed library version.
        """

    def __len__(self) -> int:
        """Get balanced dataset length (original length plus repeats)."""

    def __getitem__(self, idx: int):
        """Get balanced sample by index."""
class ConcatDataset:
    """Present several datasets as one dataset, indexed end to end."""

    def __init__(self, datasets: list):
        """
        Concatenate multiple datasets.

        Parameters:
        - datasets: List of datasets to concatenate; the samples of each
          dataset are indexed after those of the datasets before it
        """

    def __len__(self) -> int:
        """Get total length of concatenated datasets (sum of their lengths)."""

    def __getitem__(self, idx: int):
        """Get sample ``idx`` from whichever dataset contains it."""

    def get_dataset_idx_and_sample_idx(self, idx: int) -> tuple:
        """
        Map a global index to the dataset holding it and the local index.

        Parameters:
        - idx: Global index

        Returns:
        Tuple of (dataset_idx, sample_idx)
        """
class RepeatDataset:
    """Make one pass over the wrapper cover the original dataset several times."""

    def __init__(self, dataset, times: int):
        """
        Repeat dataset multiple times.

        Parameters:
        - dataset: Original dataset
        - times: Number of repetitions; the wrapper holds
          ``times * len(dataset)`` samples
        """

    def __len__(self) -> int:
        """Get repeated dataset length (``times * len(dataset)``)."""

    def __getitem__(self, idx: int):
        """Get sample ``idx % len(dataset)`` of the original dataset."""
```

### Data Samplers

Sampling strategies for data loading in different training scenarios: `DefaultSampler` for epoch-based training and `InfiniteSampler` for continuous, iteration-based training.

```python { .api }
class DefaultSampler:
    """Epoch-based sampler that splits sample indices across distributed ranks."""

    def __init__(self, dataset, shuffle: bool = True, seed: int = None, round_up: bool = True):
        """
        Default data sampler.

        Parameters:
        - dataset: Dataset to sample from
        - shuffle: Whether to shuffle data
        - seed: Random seed shared by all ranks; if None, one is generated
          and synchronized across ranks
        - round_up: Whether to add extra samples so the number of samples
          is evenly divisible by the world size
        """

    def __iter__(self):
        """Iterate over this rank's sample indices (shuffled by seed + epoch)."""

    def __len__(self) -> int:
        """Get number of samples yielded on this rank."""

    def set_epoch(self, epoch: int):
        """
        Set epoch for sampling.

        With ``shuffle=True``, call this at the start of every epoch so each
        epoch uses a different (rank-consistent) ordering.

        Parameters:
        - epoch: Current epoch
        """
class InfiniteSampler:
    """Sampler that yields sample indices endlessly, for iteration-based training."""

    def __init__(self, dataset, shuffle: bool = True, seed: int = None):
        """
        Infinite data sampler for continuous sampling.

        Parameters:
        - dataset: Dataset to sample from
        - shuffle: Whether to shuffle data
        - seed: Random seed
        """

    def __iter__(self):
        """Infinite iterator over sample indices; it never raises StopIteration."""

    def __len__(self) -> int:
        """Get the length of the underlying dataset (not of the endless stream)."""

    def set_epoch(self, epoch: int):
        """
        Set epoch for sampling.

        Kept for interface compatibility with epoch-based samplers. The
        infinite index stream has no epoch boundaries, so this does not
        change the sampling order (in mmengine it is a no-op).

        Parameters:
        - epoch: Current epoch (ignored)
        """
```

### Data Loading Utilities

Utility functions for dataset initialization, DataLoader worker seeding, and batch collation.

```python { .api }
def force_full_init(dataset):
    """
    Decorator that makes a dataset method call ``full_init()`` first.

    Despite the parameter name used here, the argument is the *method*
    being decorated (e.g. ``@force_full_init`` on ``__len__``), not a
    dataset instance. The returned wrapper checks whether the object it is
    called on is fully initialized, calls ``full_init()`` if not, and then
    runs the wrapped method.

    To initialize a dataset instance eagerly, call ``dataset.full_init()``
    instead of passing the dataset to this function.

    Parameters:
    - dataset: Method to decorate; its first positional argument must be an
      object providing ``full_init()``

    Returns:
    Wrapped method
    """
def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """
    Worker initialization function for DataLoader.

    Seeds ``random``, NumPy and torch in each worker with
    ``num_workers * rank + worker_id + seed``, so every worker on every rank
    draws a distinct, reproducible random stream. DataLoader passes only
    ``worker_id``; bind the remaining arguments with ``functools.partial``,
    which, unlike a lambda, can be pickled when workers are started with the
    'spawn' method.

    Parameters:
    - worker_id: Worker ID, supplied by DataLoader
    - num_workers: Total number of workers per process
    - rank: Process rank
    - seed: Base random seed
    """
def pseudo_collate(batch: list):
    """
    Group samples into a batch without stacking tensors.

    The batch is transposed to match the structure of its elements: a list
    of dicts becomes a dict of lists (one list per key), and a list of
    sequences becomes a sequence of lists. Unlike ``default_collate``, leaf
    values (tensors, numbers, arrays) are left as-is rather than stacked.

    Parameters:
    - batch: List of samples

    Returns:
    Batch with the same container type as each sample, e.g.
    ``[dict(a=1), dict(a=2)] -> dict(a=[1, 2])``
    """
def default_collate(batch: list):
    """
    Default collate function for batching data.

    Follows the same structure-preserving grouping as ``pseudo_collate``,
    but tensors are stacked into batched tensors whose first dimension is
    the batch size.

    Parameters:
    - batch: List of samples

    Returns:
    Collated batch
    """
```

### Collate Functions

Registry of the available collate functions, which configs refer to by name.

```python { .api }
# Registry (not a plain dict) mapping names to collate functions. Add entries
# with COLLATE_FUNCTIONS.register_module() and refer to them by name, e.g.
# collate_fn=dict(type='pseudo_collate') in a DataLoader config.
COLLATE_FUNCTIONS: 'Registry'
```

## Usage Examples

### Basic Dataset Implementation

```python
import json
import os

from mmengine.dataset import BaseDataset


class CustomDataset(BaseDataset):
    """Example dataset reading a JSON list of {'filename', 'label'} records."""

    def __init__(self, ann_file, data_root, **kwargs):
        # Forward data_root: BaseDataset.__init__ assigns self.data_root
        # itself, so a value set here beforehand would be overwritten with
        # the default ''. (BaseDataset also joins a relative ann_file onto
        # data_root.)
        super().__init__(ann_file=ann_file, data_root=data_root, **kwargs)

    def load_data_list(self):
        """Load the annotation file and attach each sample's image path."""
        with open(self.ann_file, 'r', encoding='utf-8') as f:
            data_list = json.load(f)

        # Process annotations
        for data_info in data_list:
            data_info['img_path'] = os.path.join(
                self.data_root, data_info['filename']
            )

        return data_list

    def prepare_data(self, idx):
        """Build the pipeline input for ``idx`` and run the pipeline on it."""
        data_info = self.get_data_info(idx)
        results = {
            'img_path': data_info['img_path'],
            'gt_label': data_info['label'],
            'sample_idx': idx,
        }
        # The base implementation applies self.pipeline; an override must do
        # the same, or the configured transforms are silently skipped.
        return self.pipeline(results)
# Usage
dataset = CustomDataset(
    ann_file='annotations.json',
    data_root='data/',
    # Each dict is a transform config resolved by its 'type' name.
    pipeline=[
        dict(type='LoadImageFromFile'),
        dict(type='Resize', scale=(224, 224)),
        dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        dict(type='PackInputs'),
    ],
)
```

### Data Pipeline Configuration

```python
from mmengine.dataset import Compose

# Define data pipeline
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    dict(type='PackInputs'),
]

val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=256),
    dict(type='CenterCrop', size=224),
    dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    dict(type='PackInputs'),
]

# Compose objects can run a pipeline directly on a single data dict,
# independently of any dataset.
train_transforms = Compose(train_pipeline)
val_transforms = Compose(val_pipeline)

# Apply to datasets: they take the config list and build their own Compose.
# CustomDataset requires data_root, so it is passed explicitly.
train_dataset = CustomDataset(ann_file='train.json', data_root='data/', pipeline=train_pipeline)
val_dataset = CustomDataset(ann_file='val.json', data_root='data/', pipeline=val_pipeline)
```

### Dataset Wrappers Usage

```python
from mmengine.dataset import ClassBalancedDataset, ConcatDataset, RepeatDataset

# Class balancing for imbalanced datasets. The wrapped dataset must provide
# per-sample category ids via get_cat_ids(idx).
balanced_dataset = ClassBalancedDataset(
    dataset=train_dataset,
    oversample_thr=1e-3,
)

# Concatenate multiple datasets (dataset1..dataset3 stand for datasets that
# have already been built)
combined_dataset = ConcatDataset([
    dataset1,
    dataset2,
    dataset3,
])

# Repeat a small dataset so one pass yields 10x as many samples
repeated_dataset = RepeatDataset(
    dataset=small_dataset,
    times=10,
)
```

### Using DefaultSampler with a DataLoader

```python
from functools import partial

import torch.utils.data as data
from mmengine.dataset import DefaultSampler, default_collate, worker_init_fn

# Create sampler
sampler = DefaultSampler(
    dataset=train_dataset,
    shuffle=True,
    seed=42,
    round_up=True,
)

num_workers = 4

# Use with DataLoader. worker_init_fn only runs when num_workers > 0, and
# functools.partial (unlike a lambda) stays picklable when workers are
# started with the 'spawn' method.
dataloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=num_workers,
    collate_fn=default_collate,
    worker_init_fn=partial(
        worker_init_fn, num_workers=num_workers, rank=0, seed=42
    ),
)
```

### Distributed Data Loading

```python
from functools import partial

import torch.utils.data as data
from torch.utils.data.distributed import DistributedSampler
from mmengine.dataset import worker_init_fn
from mmengine.dist import get_rank

# Load annotations eagerly (needed if the dataset was built with
# lazy_init=True). force_full_init is a method decorator, not a function to
# call on a dataset, so call full_init() directly.
dataset.full_init()

# Create distributed sampler (requires an initialized process group)
sampler = DistributedSampler(
    dataset=dataset,
    shuffle=True,
    seed=42,
)

num_workers = 4

# DataLoader for distributed training
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=num_workers,
    pin_memory=True,
    worker_init_fn=partial(
        worker_init_fn, num_workers=num_workers, rank=get_rank(), seed=42
    ),
)

for epoch in range(num_epochs):
    # DistributedSampler only reshuffles when told the epoch changed.
    sampler.set_epoch(epoch)
    for batch in dataloader:
        train_step(batch)
```

### Infinite Sampling for Continuous Training

```python
import torch.utils.data as data
from mmengine.dataset import InfiniteSampler

# Create infinite sampler
infinite_sampler = InfiniteSampler(
    dataset=dataset,
    shuffle=True,
    seed=42,
)

# Use for continuous training
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=infinite_sampler,
)

# Create the DataLoader iterator once and keep pulling from it. Breaking out
# of and re-entering `for batch in dataloader` builds a new iterator every
# epoch (dropping batches prefetched by workers and possibly closing the
# sampler's index stream). The infinite sampler has no epoch boundary, so
# no set_epoch() call is needed.
data_iter = iter(dataloader)
for epoch in range(num_epochs):
    for _ in range(steps_per_epoch):
        batch = next(data_iter)
        # Training step
        train_step(batch)
```

### Custom Collate Function

```python
import torch


def custom_collate(batch):
    """
    Custom collate function for special data types.

    Parameters:
    - batch: List of sample dicts with 'image' (tensors of identical shape),
      'label' (number) and 'metadata' (any object) keys

    Returns:
    Dict with 'images' (tensors stacked along a new first dimension),
    'labels' (tensor of the labels) and 'metadata' (list, one entry per
    sample)
    """
    images = [sample['image'] for sample in batch]
    labels = [sample['label'] for sample in batch]
    metadata = [sample['metadata'] for sample in batch]

    return {
        'images': torch.stack(images),
        'labels': torch.tensor(labels),
        'metadata': metadata,
    }
from mmengine.dataset import COLLATE_FUNCTIONS

# Register custom collate function. COLLATE_FUNCTIONS is a registry, not a
# plain dict, so entries are added with register_module(); the function is
# registered under its __name__ ('custom_collate').
COLLATE_FUNCTIONS.register_module(module=custom_collate)

# Dataset configuration (CustomDataset must itself be registered in the
# dataset registry to be built from a config)
dataset_cfg = dict(
    type='CustomDataset',
    # ... other configs
)

# collate_fn belongs to the DataLoader configuration, not the dataset
# configuration, and refers to the registered name.
train_dataloader = dict(
    batch_size=32,
    dataset=dataset_cfg,
    collate_fn=dict(type='custom_collate'),
)
```

### Dataset Filtering

```python
class FilteredDataset(BaseDataset):
    """Dataset that drops samples whose width or height is below ``min_size``."""

    def __init__(self, min_size=32, **kwargs):
        # Set before super().__init__(): unless lazy_init=True, the base
        # constructor loads and filters the annotations immediately, and
        # filter_data() reads self.min_size.
        self.min_size = min_size
        super().__init__(**kwargs)

    def filter_data(self):
        """
        Filter out samples smaller than min_size.

        Samples without a 'width' or 'height' entry count as size 0 and are
        dropped.

        Returns:
        List of the data infos that are at least min_size on both sides
        """
        return [
            data_info
            for data_info in self.data_list
            if data_info.get('width', 0) >= self.min_size
            and data_info.get('height', 0) >= self.min_size
        ]
# Usage. FilteredDataset.filter_data() does not consult filter_cfg, so size
# filtering is controlled by min_size alone (read self.filter_cfg inside
# filter_data() if config-driven filtering is wanted).
filtered_dataset = FilteredDataset(
    ann_file='annotations.json',
    min_size=64,
)
```