or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

configuration.md dataset.md distributed.md fileio.md index.md logging.md models.md optimization.md registry.md training.md visualization.md

docs/distributed.md

0

# Distributed Training

1

2

Multi-GPU and multi-node training support covering several distribution strategies (DDP, FSDP, DeepSpeed, and ColossalAI), together with communication utilities and device management. The system provides comprehensive distributed training capabilities for scalable deep learning.

3

4

## Capabilities

5

6

### Distributed Initialization

7

8

Functions for initializing distributed training environments.

9

10

```python { .api }

11

def init_dist(launcher: str, backend: str = 'nccl', **kwargs):

12

"""

13

Initialize distributed training.

14

15

Parameters:

16

- launcher: Launcher type ('pytorch', 'mpi', 'slurm')

17

- backend: Communication backend ('nccl', 'gloo', 'mpi')

18

- **kwargs: Additional initialization arguments

19

"""

20

21

def init_local_group(group_size: int):

22

"""

23

Initialize local process group.

24

25

Parameters:

26

- group_size: Size of local group

27

"""

28

29

def get_backend() -> str:

30

"""

31

Get current distributed backend.

32

33

Returns:

34

Backend name

35

"""

36

37

def infer_launcher() -> str:

38

"""

39

Infer distributed launcher from environment.

40

41

Returns:

42

Inferred launcher type

43

"""

44

```

45

46

### Process Information

47

48

Functions for getting information about distributed processes.

49

50

```python { .api }

51

def get_dist_info() -> tuple:

52

"""

53

Get distributed training information.

54

55

Returns:

56

Tuple of (rank, world_size)

57

"""

58

59

def get_rank() -> int:

60

"""

61

Get current process rank.

62

63

Returns:

64

Process rank

65

"""

66

67

def get_world_size() -> int:

68

"""

69

Get total number of processes.

70

71

Returns:

72

World size

73

"""

74

75

def get_local_rank() -> int:

76

"""

77

Get local rank within node.

78

79

Returns:

80

Local rank

81

"""

82

83

def get_local_size() -> int:

84

"""

85

Get local group size.

86

87

Returns:

88

Local group size

89

"""

90

91

def get_local_group():

92

"""

93

Get local process group.

94

95

Returns:

96

Local process group

97

"""

98

99

def is_main_process() -> bool:

100

"""

101

Check if current process is main process.

102

103

Returns:

104

True if main process

105

"""

106

107

def is_distributed() -> bool:

108

"""

109

Check if in distributed mode.

110

111

Returns:

112

True if distributed training is enabled

113

"""

114

115

def get_default_group():

116

"""

117

Get default process group.

118

119

Returns:

120

Default process group

121

"""

122

```

123

124

### Communication Functions

125

126

Functions for inter-process communication in distributed training.

127

128

```python { .api }

129

def all_reduce(tensor, op: str = 'sum', group=None, async_op: bool = False):

130

"""

131

All-reduce operation across processes.

132

133

Parameters:

134

- tensor: Tensor to reduce

135

- op: Reduction operation ('sum', 'mean', 'max', 'min')

136

- group: Process group

137

- async_op: Whether to perform asynchronously

138

"""

139

140

def all_gather(tensor_list: list, tensor, group=None, async_op: bool = False):

141

"""

142

All-gather operation across processes.

143

144

Parameters:

145

- tensor_list: List to store gathered tensors

146

- tensor: Tensor to gather

147

- group: Process group

148

- async_op: Whether to perform asynchronously

149

"""

150

151

def all_gather_object(object_list: list, obj, group=None):

152

"""

153

All-gather Python objects across processes.

154

155

Parameters:

156

- object_list: List to store gathered objects

157

- obj: Object to gather

158

- group: Process group

159

"""

160

161

def broadcast(tensor, src: int = 0, group=None, async_op: bool = False):

162

"""

163

Broadcast tensor from source process.

164

165

Parameters:

166

- tensor: Tensor to broadcast

167

- src: Source process rank

168

- group: Process group

169

- async_op: Whether to perform asynchronously

170

"""

171

172

def broadcast_object_list(object_list: list, src: int = 0, group=None):

173

"""

174

Broadcast list of objects from source process.

175

176

Parameters:

177

- object_list: List of objects to broadcast

178

- src: Source process rank

179

- group: Process group

180

"""

181

182

def gather(tensor, gather_list: list = None, dst: int = 0, group=None, async_op: bool = False):

183

"""

184

Gather tensors to destination process.

185

186

Parameters:

187

- tensor: Tensor to gather

188

- gather_list: List to store gathered tensors

189

- dst: Destination process rank

190

- group: Process group

191

- async_op: Whether to perform asynchronously

192

"""

193

194

def gather_object(obj, object_gather_list: list = None, dst: int = 0, group=None):

195

"""

196

Gather Python objects to destination process.

197

198

Parameters:

199

- obj: Object to gather

200

- object_gather_list: List to store gathered objects

201

- dst: Destination process rank

202

- group: Process group

203

"""

204

205

def reduce(tensor, dst: int = 0, op: str = 'sum', group=None, async_op: bool = False):

206

"""

207

Reduce tensor to destination process.

208

209

Parameters:

210

- tensor: Tensor to reduce

211

- dst: Destination process rank

212

- op: Reduction operation

213

- group: Process group

214

- async_op: Whether to perform asynchronously

215

"""

216

217

def barrier(group=None, async_op: bool = False):

218

"""

219

Synchronization barrier across processes.

220

221

Parameters:

222

- group: Process group

223

- async_op: Whether to perform asynchronously

224

"""

225

226

def sync_random_seed(seed: int = None, device: str = 'cuda') -> int:

227

"""

228

Synchronize random seed across processes.

229

230

Parameters:

231

- seed: Random seed (generated if None)

232

- device: Device for synchronization

233

234

Returns:

235

Synchronized seed

236

"""

237

```

238

239

### Advanced Communication

240

241

Higher-level communication functions for complex operations.

242

243

```python { .api }

244

def all_reduce_dict(py_dict: dict, op: str = 'mean', group=None, to_float: bool = True) -> dict:

245

"""

246

All-reduce dictionary of tensors.

247

248

Parameters:

249

- py_dict: Dictionary of tensors

250

- op: Reduction operation

251

- group: Process group

252

- to_float: Whether to convert to float

253

254

Returns:

255

Reduced dictionary

256

"""

257

258

def all_reduce_params(params, coalesce: bool = True, bucket_size_mb: int = -1):

259

"""

260

All-reduce model parameters.

261

262

Parameters:

263

- params: Model parameters

264

- coalesce: Whether to coalesce parameters

265

- bucket_size_mb: Bucket size in MB

266

"""

267

268

def collect_results(result_part: list, size: int, tmpdir: str = None) -> list:

269

"""

270

Collect results from all processes.

271

272

Parameters:

273

- result_part: Partial results from current process

274

- size: Total size of dataset

275

- tmpdir: Temporary directory for file-based collection

276

277

Returns:

278

Collected results from all processes

279

"""

280

281

def collect_results_cpu(result_part: list, size: int, tmpdir: str = None) -> list:

282

"""

283

Collect results to CPU from all processes.

284

285

Parameters:

286

- result_part: Partial results

287

- size: Total dataset size

288

- tmpdir: Temporary directory

289

290

Returns:

291

CPU results from all processes

292

"""

293

294

def collect_results_gpu(result_part: list, size: int) -> list:

295

"""

296

Collect results on GPU from all processes.

297

298

Parameters:

299

- result_part: Partial results

300

- size: Total dataset size

301

302

Returns:

303

GPU results from all processes

304

"""

305

```

306

307

### Device Management

308

309

Functions for managing devices in distributed environments.

310

311

```python { .api }

312

def get_device() -> str:

313

"""

314

Get current device.

315

316

Returns:

317

Device string ('cuda:0', 'cpu', etc.)

318

"""

319

320

def get_data_device(data) -> str:

321

"""

322

Get device of data.

323

324

Parameters:

325

- data: Input data (tensor, dict, list, etc.)

326

327

Returns:

328

Device string

329

"""

330

331

def get_comm_device(group=None) -> str:

332

"""

333

Get communication device for process group.

334

335

Parameters:

336

- group: Process group

337

338

Returns:

339

Communication device

340

"""

341

342

def cast_data_device(data, device: str, out=None):

343

"""

344

Cast data to specified device.

345

346

Parameters:

347

- data: Input data

348

- device: Target device

349

- out: Output container

350

351

Returns:

352

Data on target device

353

"""

354

```

355

356

### Model Wrappers

357

358

Distributed data parallel wrappers for models.

359

360

```python { .api }

361

class MMDistributedDataParallel:

362

def __init__(self, module, device_ids: list = None, output_device: int = None, broadcast_buffers: bool = True, find_unused_parameters: bool = False, bucket_cap_mb: int = 25, gradient_as_bucket_view: bool = False):

363

"""

364

MMEngine's distributed data parallel wrapper.

365

366

Parameters:

367

- module: Model module to wrap

368

- device_ids: Device IDs for this process

369

- output_device: Output device ID

370

- broadcast_buffers: Whether to broadcast buffers

371

- find_unused_parameters: Whether to find unused parameters

372

- bucket_cap_mb: Bucket capacity in MB

373

- gradient_as_bucket_view: Whether to use gradient bucket view

374

"""

375

376

def forward(self, *inputs, **kwargs):

377

"""

378

Forward pass with gradient synchronization.

379

380

Parameters:

381

- *inputs: Input arguments

382

- **kwargs: Input keyword arguments

383

384

Returns:

385

Model outputs

386

"""

387

388

class MMSeparateDistributedDataParallel:

389

def __init__(self, module, device_ids: list = None, output_device: int = None, broadcast_buffers: bool = True, find_unused_parameters: bool = False):

390

"""

391

Separate distributed data parallel for different parameter groups.

392

393

Parameters:

394

- module: Model module

395

- device_ids: Device IDs

396

- output_device: Output device

397

- broadcast_buffers: Whether to broadcast buffers

398

- find_unused_parameters: Whether to find unused parameters

399

"""

400

401

class MMFullyShardedDataParallel:

402

def __init__(self, module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=None, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id: int = None, sync_module_states: bool = False, forward_prefetch: bool = False, limit_all_gathers: bool = True, use_orig_params: bool = False):

403

"""

404

Fully sharded data parallel wrapper (PyTorch >=2.0).

405

406

Parameters:

407

- module: Model module

408

- process_group: Process group

409

- sharding_strategy: Sharding strategy

410

- cpu_offload: CPU offload configuration

411

- auto_wrap_policy: Auto-wrap policy

412

- backward_prefetch: Backward prefetch strategy

413

- mixed_precision: Mixed precision policy

414

- ignored_modules: Modules to ignore

415

- param_init_fn: Parameter initialization function

416

- device_id: Device ID

417

- sync_module_states: Whether to sync module states

418

- forward_prefetch: Whether to prefetch in forward

419

- limit_all_gathers: Whether to limit all-gathers

420

- use_orig_params: Whether to use original parameters

421

"""

422

423

def is_model_wrapper(model) -> bool:

424

"""

425

Check if model is wrapped with distributed wrapper.

426

427

Parameters:

428

- model: Model to check

429

430

Returns:

431

True if model is wrapped

432

"""

433

```

434

435

### Utility Decorators

436

437

Decorators for distributed training utilities.

438

439

```python { .api }

440

def master_only(func):

441

"""

442

Decorator to run function only on master process.

443

444

Parameters:

445

- func: Function to decorate

446

447

Returns:

448

Decorated function

449

"""

450

```

451

452

## Usage Examples

453

454

### Basic Distributed Training Setup

455

456

```python

457

import torch

458

from mmengine.dist import get_dist_info, get_local_rank, init_dist
from mmengine.runner import Runner

459

460

# Initialize distributed training

461

init_dist('pytorch', backend='nccl')

462

463

# Get distributed info

464

rank, world_size = get_dist_info()

465

local_rank = get_local_rank()

466

467

# Set device

468

torch.cuda.set_device(local_rank)

469

device = torch.device('cuda', local_rank)

470

471

# Create model and move to device

472

model = MyModel().to(device)

473

474

# Wrap with DDP

475

from mmengine.model import MMDistributedDataParallel

476

model = MMDistributedDataParallel(

477

model,

478

device_ids=[local_rank],

479

broadcast_buffers=False,

480

find_unused_parameters=False

481

)

482

483

# Create runner with distributed configuration

484

runner = Runner(

485

model=model,

486

work_dir='./work_dir',

487

train_dataloader=train_loader,

488

launcher='pytorch'

489

)

490

491

runner.train()

492

```

493

494

### Communication Examples

495

496

```python

497

import torch

498

from mmengine.dist import all_gather, all_reduce, all_reduce_dict, broadcast, get_rank, get_world_size

499

500

# All-reduce operation

501

loss = torch.tensor(0.5).cuda()

502

all_reduce(loss, op='mean') # Average loss across all processes

503

504

# All-gather operation

505

local_tensor = torch.randn(4).cuda()

506

gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(get_world_size())]

507

all_gather(gathered_tensors, local_tensor)

508

509

# Broadcast operation

510

if get_rank() == 0:

511

data = torch.randn(10).cuda()

512

else:

513

data = torch.zeros(10).cuda()

514

broadcast(data, src=0)

515

516

# Dictionary all-reduce

517

metrics = {'loss': torch.tensor(0.5), 'acc': torch.tensor(0.9)}

518

reduced_metrics = all_reduce_dict(metrics, op='mean')

519

```

520

521

### Result Collection

522

523

```python

524

from mmengine.dist import collect_results, is_main_process

525

526

# Collect evaluation results from all processes

527

def evaluate_model(model, dataloader):

528

results = []

529

for batch in dataloader:

530

outputs = model(batch)

531

results.extend(outputs)

532

533

# Collect results from all processes

534

all_results = collect_results(results, len(dataloader.dataset))

535

536

# Only compute metrics on main process

537

if is_main_process():

538

metrics = compute_metrics(all_results)

539

return metrics

540

return {}

541

```

542

543

### Master-Only Operations

544

545

```python

546

import torch
from mmengine.dist import master_only, is_main_process

547

548

@master_only

549

def save_checkpoint(model, path):

550

"""Save checkpoint only on master process."""

551

torch.save(model.state_dict(), path)

552

553

@master_only

554

def log_metrics(metrics):

555

"""Log metrics only on master process."""

556

print(f"Metrics: {metrics}")

557

558

# Alternative approach

559

def training_step(model, data):

560

loss = model(data)

561

562

if is_main_process():

563

print(f"Loss: {loss.item()}")

564

565

return loss

566

```

567

568

### Advanced DDP Configuration

569

570

```python

571

from mmengine.model import MMDistributedDataParallel, MMSeparateDistributedDataParallel

572

573

# DDP with gradient bucketing and unused parameter detection

574

model = MMDistributedDataParallel(

575

model,

576

device_ids=[local_rank],

577

output_device=local_rank,

578

broadcast_buffers=True,

579

find_unused_parameters=True,

580

bucket_cap_mb=25,

581

gradient_as_bucket_view=True

582

)

583

584

# Separate DDP for models with different parameter update frequencies

585

model = MMSeparateDistributedDataParallel(

586

model,

587

device_ids=[local_rank],

588

find_unused_parameters=True

589

)

590

```

591

592

### FSDP Usage (PyTorch >=2.0)

593

594

```python

595

from mmengine.model import MMFullyShardedDataParallel

596

from torch.distributed.fsdp import ShardingStrategy, CPUOffload

597

598

# FSDP configuration

599

model = MMFullyShardedDataParallel(

600

model,

601

sharding_strategy=ShardingStrategy.FULL_SHARD,

602

cpu_offload=CPUOffload(offload_params=True),

603

mixed_precision=None,

604

backward_prefetch=None,

605

forward_prefetch=False,

606

limit_all_gathers=True

607

)

608

```

609

610

### Random Seed Synchronization

611

612

```python

613

import numpy as np
import torch
from mmengine.dist import sync_random_seed

614

615

# Synchronize random seed across all processes

616

seed = sync_random_seed(42)

617

618

# Use synchronized seed

619

torch.manual_seed(seed)

620

np.random.seed(seed)

621

```