or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

configuration.md · dataset.md · distributed.md · fileio.md · index.md · logging.md · models.md · optimization.md · registry.md · training.md · visualization.md

docs/dataset.md

0

# Dataset and Data Processing

1

2

A dataset abstraction layer that supports multiple dataset types, data transformations, sampling strategies, and data-loading utilities optimized for distributed training. Together these components provide flexible data-processing pipelines for machine-learning workflows.

3

4

## Capabilities

5

6

### Base Dataset Class

7

8

Base class for all datasets, providing a standardized interface and optional lazy initialization.

9

10

```python { .api }

11

from typing import Optional, Sequence, Union


class BaseDataset:
    def __init__(self,
                 ann_file: str = '',
                 metainfo: Optional[dict] = None,
                 data_root: str = '',
                 data_prefix: Optional[dict] = None,
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: list = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000):
        """
        Base dataset class.

        Parameters:
        - ann_file: Annotation file path
        - metainfo: Dataset meta information
        - data_root: Data root directory
        - data_prefix: Prefix for different data types
        - filter_cfg: Config for filtering data
        - indices: Subset of the data to use. An int keeps the first
          ``indices`` samples; a sequence of ints keeps the samples at those
          positions. None (default) uses all samples.
        - serialize_data: Whether to serialize the data list into a compact
          buffer so dataloader workers can share it instead of each holding
          its own copy
        - pipeline: Data processing pipeline (transform configs or callables)
        - test_mode: Whether in test mode
        - lazy_init: Whether to defer loading annotations until full_init()
          is called
        - max_refetch: Maximum number of attempts to fetch a replacement
          sample when the pipeline returns None for a sample (outside test
          mode)
        """

    def __len__(self) -> int:
        """
        Get dataset size.

        Returns:
        Dataset length
        """

    def __getitem__(self, idx: int) -> dict:
        """
        Get the processed data sample at ``idx``.

        Outside test mode, if the pipeline returns None for a sample, another
        sample is fetched instead, up to ``max_refetch`` times.

        Parameters:
        - idx: Sample index

        Returns:
        Data sample
        """

    def get_data_info(self, idx: int) -> dict:
        """
        Get data information by index.

        Parameters:
        - idx: Sample index

        Returns:
        Data information dictionary
        """

    def prepare_data(self, idx: int) -> Optional[dict]:
        """
        Fetch the data info for ``idx`` and run it through the pipeline.

        Parameters:
        - idx: Sample index

        Returns:
        Pipeline output, or None if a transform dropped the sample
        """

    def load_data_list(self) -> list:
        """
        Load annotation file and return data list.

        Returns:
        List of data information
        """

    def filter_data(self) -> list:
        """
        Filter data according to filter_cfg.

        Returns:
        Filtered data list
        """

    def get_subset_(self, indices: Union[int, Sequence[int]]) -> None:
        """
        Convert this dataset into a subset of itself, in place.

        Following the trailing-underscore convention, the dataset object
        itself is modified and nothing is returned; use ``get_subset`` to get
        a new dataset instead.

        Parameters:
        - indices: An int keeps the first ``indices`` samples; a sequence of
          ints keeps the samples at those positions
        """

    def get_subset(self, indices: Union[int, Sequence[int]]) -> 'BaseDataset':
        """
        Return a new dataset containing a subset of this one.

        Parameters:
        - indices: Same meaning as for ``get_subset_``

        Returns:
        Dataset subset (the original dataset is left unchanged)
        """

    @property
    def metainfo(self) -> dict:
        """Get dataset meta information."""

    def full_init(self):
        """
        Fully initialize the dataset.

        Loads and filters the annotations, then applies ``indices`` and
        serialization if configured. It is only needed explicitly when the
        dataset was built with ``lazy_init=True``; calling it on an already
        initialized dataset does nothing.
        """

104

```

105

106

### Data Transforms

107

108

Transform composition system for data preprocessing and augmentation.

109

110

```python { .api }

111

from typing import Callable, Optional, Sequence, Union


class Compose:
    def __init__(self, transforms: Sequence[Union[dict, Callable]]):
        """
        Compose multiple transforms.

        Parameters:
        - transforms: Sequence of transform configs (dicts) or callable
          transform instances, applied in order
        """

    def __call__(self, data: dict) -> Optional[dict]:
        """
        Apply transforms to data.

        Transforms run in order, each receiving the previous one's output. If
        any transform returns None (e.g. it could not load or augment the
        sample), the remaining transforms are skipped and None is returned.
        Outside test mode, BaseDataset treats this as a failed sample and
        fetches another one.

        Parameters:
        - data: Input data dictionary

        Returns:
        Transformed data, or None if a transform dropped the sample
        """

    def __repr__(self) -> str:
        """String representation of transforms."""

133

```

134

135

### Dataset Wrappers

136

137

Wrapper classes for modifying dataset behavior.

138

139

```python { .api }

140

from typing import Optional


class ClassBalancedDataset:
    def __init__(self, dataset, oversample_thr: float = 1e-3, random_state: Optional[int] = None):
        """
        Dataset wrapper for class balancing through oversampling.

        Samples containing rare categories are repeated so that they appear
        more often per epoch. NOTE(review): the wrapped dataset must expose
        per-sample category ids (``get_cat_ids(idx)`` in mmengine) — confirm
        for your dataset.

        Parameters:
        - dataset: Original dataset
        - oversample_thr: Category frequency threshold; categories whose
          frequency is below it are oversampled, the others are not
        - random_state: Random state for reproducibility (None: not fixed)
        """

    def __len__(self) -> int:
        """Get the dataset length after oversampling."""

    def __getitem__(self, idx: int):
        """Get balanced sample by index."""

156

157

class ConcatDataset:
    def __init__(self, datasets: list):
        """
        Chain several datasets so they behave like one.

        Parameters:
        - datasets: Datasets to join, in order
        """

    def __len__(self) -> int:
        """Return the summed length of all member datasets."""

    def __getitem__(self, idx: int):
        """Look up the global index in the member dataset that owns it."""

    def get_dataset_idx_and_sample_idx(self, idx: int) -> tuple:
        """
        Translate a global index into a (member, local) index pair.

        Parameters:
        - idx: Index into the concatenated dataset

        Returns:
        Tuple of (dataset_idx, sample_idx)
        """

182

183

class RepeatDataset:
    def __init__(self, dataset, times: int):
        """
        Present a dataset as if it were repeated several times over.

        Parameters:
        - dataset: Dataset to repeat
        - times: How many copies to chain together
        """

    def __len__(self) -> int:
        """Return the length of the repeated view."""

    def __getitem__(self, idx: int):
        """Fetch the sample at ``idx`` of the repeated view."""

198

```

199

200

### Data Samplers

201

202

Sampling strategies for data loading in different training scenarios.

203

204

```python { .api }

205

from typing import Optional


class DefaultSampler:
    def __init__(self, dataset, shuffle: bool = True, seed: Optional[int] = None, round_up: bool = True):
        """
        Default sampler for epoch-based training, usable with or without
        distributed training: each process (rank) iterates over its own share
        of the dataset indices.

        Parameters:
        - dataset: Dataset to sample from
        - shuffle: Whether to shuffle the indices
        - seed: Random seed used for shuffling. If None, a random seed is
          chosen (NOTE(review): mmengine synchronizes it across ranks — confirm)
        - round_up: Whether to pad the index list (by repeating indices) so
          every rank receives the same number of samples
        """

    def __iter__(self):
        """Iterate over the sample indices assigned to the current rank."""

    def __len__(self) -> int:
        """Get the number of samples yielded on the current rank."""

    def set_epoch(self, epoch: int) -> None:
        """
        Set the current epoch.

        With shuffle=True the shuffle order is derived from the seed and the
        epoch, so call this at the start of every epoch to get a different
        order each epoch.

        Parameters:
        - epoch: Current epoch
        """

222

223

from typing import Optional


class InfiniteSampler:
    def __init__(self, dataset, shuffle: bool = True, seed: Optional[int] = None):
        """
        Infinite data sampler for iteration-based training.

        Its iterator never stops: after each pass over the dataset it starts
        a new pass (reshuffled when shuffle=True), so the training loop must
        decide when to stop.

        Parameters:
        - dataset: Dataset to sample from
        - shuffle: Whether to shuffle data
        - seed: Random seed; None means a random seed is chosen
        """

    def __iter__(self):
        """Infinite iterator over sample indices."""

    def __len__(self) -> int:
        """Get the dataset length (the size of one pass, not of the stream)."""

    def set_epoch(self, epoch: int) -> None:
        """
        Set epoch for sampling.

        Provided for interface compatibility with epoch-based samplers; an
        infinite stream has no epoch boundaries. NOTE(review): in mmengine
        this is a no-op — confirm for the installed version.

        Parameters:
        - epoch: Current epoch
        """

247

```

248

249

### Data Loading Utilities

250

251

Utility functions for data loading and processing.

252

253

```python { .api }

254

from typing import Callable


def force_full_init(old_func: Callable) -> Callable:
    """
    Decorator for dataset methods that need a fully initialized dataset.

    It wraps a method (such as ``__len__`` or ``get_data_info``) so that,
    when the method is called on a dataset that has not been fully
    initialized yet (e.g. one built with ``lazy_init=True``), ``full_init()``
    runs first. It is applied to methods, not to dataset objects; to
    initialize a dataset eagerly, call ``dataset.full_init()``.

    Parameters:
    - old_func: Method to wrap; its first argument is the dataset instance

    Returns:
    The wrapped method
    """

261

262

def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
    """
    Seed a DataLoader worker process; pass it as the DataLoader's
    ``worker_init_fn``.

    Parameters:
    - worker_id: Index of this worker within the DataLoader
    - num_workers: Number of workers per DataLoader
    - rank: Rank of the owning process
    - seed: Base random seed
    """

272

273

from typing import Any, Sequence


def pseudo_collate(batch: Sequence) -> Any:
    """
    Collate a batch without stacking tensors.

    Unlike default_collate, values are not stacked into batch tensors.
    Instead the batch is regrouped to mirror the structure of a single
    sample: for example, a list of dicts becomes a dict of lists,
    ``[{'a': 1}, {'a': 2}] -> {'a': [1, 2]}``. This suits samples whose
    tensors differ in shape.

    Parameters:
    - batch: List of samples

    Returns:
    Batch with the same structure as one sample, holding per-sample lists
    """

283

284

def default_collate(batch: list):
    """
    Merge a list of samples into a single batch; this is the standard
    batching function.

    Parameters:
    - batch: Samples produced by the dataset

    Returns:
    The merged batch
    """

294

```

295

296

### Collate Functions

297

298

Registry of available collate functions for different data types.

299

300

```python { .api }

301

# Registry mapping names to collate functions (e.g. 'pseudo_collate',
# 'default_collate'). It is an mmengine ``Registry``, not a plain dict: add
# entries with ``@COLLATE_FUNCTIONS.register_module()`` (or
# ``register_module(module=fn)``) and look them up with
# ``COLLATE_FUNCTIONS.get(name)``.
COLLATE_FUNCTIONS: 'Registry'

302

```

303

304

## Usage Examples

305

306

### Basic Dataset Implementation

307

308

```python

309

from mmengine.dataset import BaseDataset

310

import json

311

import os

312

313

class CustomDataset(BaseDataset):
    """Example dataset that reads a JSON list of samples.

    Each entry of the annotation file is expected to contain a
    ``'filename'`` (relative to ``data_root``) and a ``'label'``.
    """

    def __init__(self, ann_file, data_root, **kwargs):
        # Forward data_root to the base class instead of assigning it here:
        # BaseDataset.__init__ stores its own data_root (default '') and,
        # unless lazy_init=True, loads the annotations immediately, so an
        # assignment made before super().__init__() would be overwritten.
        # Note that BaseDataset also resolves a relative ann_file against
        # data_root.
        super().__init__(ann_file=ann_file, data_root=data_root, **kwargs)

    def load_data_list(self):
        """Load the annotation file and resolve each image path."""
        with open(self.ann_file, 'r', encoding='utf-8') as f:
            data_list = json.load(f)

        for data_info in data_list:
            data_info['img_path'] = os.path.join(self.data_root, data_info['filename'])

        return data_list

    def prepare_data(self, idx):
        """Build the pipeline input for sample ``idx`` and run the pipeline."""
        data_info = self.get_data_info(idx)
        results = {
            'img_path': data_info['img_path'],
            'gt_label': data_info['label'],
            'sample_idx': idx,
        }
        # Overriding prepare_data replaces the base implementation, which is
        # what applies the configured pipeline, so apply it explicitly here.
        return self.pipeline(results)

339

340

# Usage: image-classification pipeline applied to each sample.
dataset = CustomDataset(
    ann_file='annotations.json',
    data_root='data/',
    pipeline=[
        dict(type='LoadImageFromFile'),          # read the image from img_path
        dict(type='Resize', scale=(224, 224)),   # fixed-size resize
        dict(type='Normalize',                   # ImageNet mean/std
             mean=[0.485, 0.456, 0.406],
             std=[0.229, 0.224, 0.225]),
        dict(type='PackInputs'),                 # pack into model inputs
    ],
)

351

```

352

353

### Data Pipeline Configuration

354

355

```python

356

from mmengine.dataset import Compose

# Define data pipelines as lists of transform configs.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='RandomFlip', prob=0.5),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    dict(type='PackInputs'),
]

val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=256),
    dict(type='CenterCrop', size=224),
    dict(type='Normalize', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    dict(type='PackInputs'),
]

# Compose turns the configs into callables. Datasets build their own
# Compose from the `pipeline` argument, so these objects are only needed
# to apply the transforms outside a dataset, e.g. train_transforms(data_info).
train_transforms = Compose(train_pipeline)
val_transforms = Compose(val_pipeline)

# Apply to datasets. CustomDataset requires data_root; the validation set
# uses test_mode=True so a sample the pipeline drops raises an error instead
# of being silently replaced by another sample.
train_dataset = CustomDataset(ann_file='train.json', data_root='data/', pipeline=train_pipeline)
val_dataset = CustomDataset(ann_file='val.json', data_root='data/', pipeline=val_pipeline, test_mode=True)

383

```

384

385

### Dataset Wrappers Usage

386

387

```python

388

from mmengine.dataset import ClassBalancedDataset, ConcatDataset, RepeatDataset

# Oversample rare categories of an imbalanced dataset.
balanced_dataset = ClassBalancedDataset(dataset=train_dataset, oversample_thr=1e-3)

# Join several datasets end to end.
combined_dataset = ConcatDataset([dataset1, dataset2, dataset3])

# Show a small dataset 10 times per epoch.
repeated_dataset = RepeatDataset(dataset=small_dataset, times=10)

408

```

409

410

### Custom Sampler Implementation

411

412

```python

413

from functools import partial

import torch.utils.data as data
from mmengine.dataset import DefaultSampler, default_collate, worker_init_fn

num_workers = 4

# Create sampler
sampler = DefaultSampler(
    dataset=train_dataset,
    shuffle=True,
    seed=42,
    round_up=True,
)

# Use with DataLoader.
# - worker_init_fn only runs inside worker processes, so num_workers must be
#   > 0 and must match the value used for seeding.
# - functools.partial (unlike a lambda) is picklable, which DataLoader needs
#   when workers are started with the 'spawn' method.
dataloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=num_workers,
    collate_fn=default_collate,
    worker_init_fn=partial(worker_init_fn, num_workers=num_workers, rank=0, seed=42),
)

# In the training loop, call sampler.set_epoch(epoch) at the start of each
# epoch so the shuffle order changes between epochs.

434

```

435

436

### Distributed Data Loading

437

438

```python

439

from functools import partial

import torch.utils.data as data
from torch.utils.data.distributed import DistributedSampler
from mmengine.dataset import worker_init_fn
from mmengine.dist import get_rank

# Make sure the dataset is fully initialized (needed if it was built with
# lazy_init=True). force_full_init is a method decorator, not a function to
# call on a dataset, so call full_init() directly.
dataset.full_init()

# Create distributed sampler
sampler = DistributedSampler(
    dataset=dataset,
    shuffle=True,
    seed=42,
)

num_workers = 4

# DataLoader for distributed training. functools.partial keeps
# worker_init_fn picklable (a lambda is not) for spawn-started workers.
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=sampler,
    num_workers=num_workers,
    pin_memory=True,
    worker_init_fn=partial(worker_init_fn, num_workers=num_workers, rank=get_rank(), seed=42),
)

# DistributedSampler shuffles with seed + epoch: call sampler.set_epoch(epoch)
# before iterating each epoch, otherwise every epoch reuses the same order.

463

```

464

465

### Infinite Sampling for Continuous Training

466

467

```python

468

import torch.utils.data as data
from mmengine.dataset import InfiniteSampler

# Create infinite sampler
infinite_sampler = InfiniteSampler(
    dataset=dataset,
    shuffle=True,
    seed=42,
)

# Use for continuous training
dataloader = data.DataLoader(
    dataset=dataset,
    batch_size=32,
    sampler=infinite_sampler,
)

# The sampler never runs out, so an "epoch" is just steps_per_epoch batches.
# Keep one iterator and pull exactly that many batches from it. Re-entering
# `for batch in dataloader` each epoch would build a new DataLoader iterator
# every time, and an `enumerate` + `break` loop fetches (and discards) one
# extra batch per epoch.
data_iter = iter(dataloader)
for epoch in range(num_epochs):
    infinite_sampler.set_epoch(epoch)
    for _ in range(steps_per_epoch):
        batch = next(data_iter)
        # Training step
        train_step(batch)

492

```

493

494

### Custom Collate Function

495

496

```python

497

import torch


def custom_collate(batch):
    """
    Collate samples that carry an image tensor, a label and free-form metadata.

    Parameters:
    - batch: List of sample dicts with keys 'image' (a tensor; all images
      must have the same shape so they can be stacked), 'label' and
      'metadata'

    Returns:
    Dict with 'images' (images stacked along a new first dimension),
    'labels' (tensor built from the labels) and 'metadata' (list, unchanged)
    """
    images = [sample['image'] for sample in batch]
    labels = [sample['label'] for sample in batch]
    metadata = [sample['metadata'] for sample in batch]

    return {
        'images': torch.stack(images),
        'labels': torch.tensor(labels),
        'metadata': metadata,
    }

513

514

from mmengine.dataset import COLLATE_FUNCTIONS

# Register the custom collate function. COLLATE_FUNCTIONS is an mmengine
# Registry (item assignment is not supported); register_module(module=...)
# registers an existing function under its own name. Decorating the
# function with @COLLATE_FUNCTIONS.register_module() works as well.
COLLATE_FUNCTIONS.register_module(module=custom_collate)

# Dataset config
dataset_cfg = dict(
    type='CustomDataset',
    # ... other configs
)

# collate_fn is a DataLoader option, so it belongs in the dataloader config
# (not the dataset config) and refers to the registered name via `type`.
train_dataloader = dict(
    batch_size=32,
    collate_fn=dict(type='custom_collate'),
    dataset=dataset_cfg,
)

523

```

524

525

### Dataset Filtering

526

527

```python

528

class FilteredDataset(BaseDataset):
    """Dataset that drops samples whose width or height is below ``min_size``."""

    def __init__(self, min_size=32, **kwargs):
        # Set before super().__init__(): unless lazy_init=True, the base
        # constructor fully initializes the dataset, which calls filter_data().
        self.min_size = min_size
        super().__init__(**kwargs)

    def filter_data(self):
        """Keep only samples that are at least ``min_size`` wide and high.

        Samples without a 'width' or 'height' entry count as size 0 and are
        dropped.
        """
        threshold = self.min_size
        return [
            info for info in self.data_list
            if info.get('width', 0) >= threshold and info.get('height', 0) >= threshold
        ]

541

542

# Usage. Only min_size drives filtering here: FilteredDataset.filter_data
# replaces the base implementation and never reads self.filter_cfg, so an
# option such as filter_cfg=dict(filter_empty_gt=True) would be silently
# ignored unless filter_data is extended to honor it.
filtered_dataset = FilteredDataset(
    ann_file='annotations.json',
    min_size=64,
)

548

```