# Distributed Training

Multi-GPU and multi-node training support with several distribution strategies, including DDP, FSDP, DeepSpeed, and ColossalAI integration, together with communication utilities and device management helpers for scaling deep learning workloads.

## Capabilities

### Distributed Initialization

Functions for initializing distributed training environments.

```python { .api }
def init_dist(launcher: str, backend: str = 'nccl', **kwargs):
    """
    Initialize distributed training.

    Parameters:
    - launcher: Launcher type ('pytorch', 'mpi', 'slurm')
    - backend: Communication backend ('nccl', 'gloo', 'mpi')
    - **kwargs: Additional initialization arguments
    """

def init_local_group(group_size: int):
    """
    Initialize local process group.

    Parameters:
    - group_size: Size of local group
    """

def get_backend() -> str:
    """
    Get current distributed backend.

    Returns:
    Backend name
    """

def infer_launcher() -> str:
    """
    Infer distributed launcher from environment.

    Returns:
    Inferred launcher type
    """
```
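
As an illustration of how launcher inference can work, the sketch below checks environment variables that common launchers usually export. The function name `guess_launcher` and the `'none'` fallback are assumptions made for this example; it is not presented as MMEngine's actual `infer_launcher` implementation.

```python
import os


def guess_launcher(env=None):
    """Guess the launcher from environment variables (illustrative only)."""
    env = os.environ if env is None else env
    if "WORLD_SIZE" in env:
        # torchrun exports WORLD_SIZE for every worker process.
        return "pytorch"
    if "SLURM_NTASKS" in env:
        # Slurm typically exports SLURM_NTASKS inside a job step.
        return "slurm"
    if "OMPI_COMM_WORLD_LOCAL_RANK" in env:
        # Open MPI exports OMPI_COMM_WORLD_LOCAL_RANK for each process.
        return "mpi"
    # Assumed fallback for a plain single-process run.
    return "none"


print(guess_launcher({"WORLD_SIZE": "8", "RANK": "0"}))
print(guess_launcher({"SLURM_NTASKS": "16"}))
print(guess_launcher({}))
```

Passing a plain dict instead of `os.environ` keeps the sketch easy to try without actually launching a distributed job.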

### Process Information

Functions for getting information about distributed processes.

```python { .api }
def get_dist_info() -> tuple:
    """
    Get distributed training information.

    Returns:
    Tuple of (rank, world_size)
    """

def get_rank() -> int:
    """
    Get current process rank.

    Returns:
    Process rank
    """

def get_world_size() -> int:
    """
    Get total number of processes.

    Returns:
    World size
    """

def get_local_rank() -> int:
    """
    Get local rank within node.

    Returns:
    Local rank
    """

def get_local_size() -> int:
    """
    Get local group size.

    Returns:
    Local group size
    """

def get_local_group():
    """
    Get local process group.

    Returns:
    Local process group
    """

def is_main_process() -> bool:
    """
    Check if current process is main process.

    Returns:
    True if main process
    """

def is_distributed() -> bool:
    """
    Check if in distributed mode.

    Returns:
    True if distributed training is enabled
    """

def get_default_group():
    """
    Get default process group.

    Returns:
    Default process group
    """
```
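
The relationship between global rank, node index, and local rank can be sketched with plain arithmetic. The example assumes ranks are assigned contiguously per node, with the same number of processes on each node; `split_rank` is a hypothetical helper, not part of the API above.

```python
def split_rank(rank, local_size):
    """Return (node_index, local_rank) for a global rank.

    Assumes every node runs `local_size` processes and that ranks are
    numbered node by node (an assumption of this sketch).
    """
    if local_size <= 0:
        raise ValueError("local_size must be positive")
    return divmod(rank, local_size)


world_size, local_size = 8, 4
for rank in range(world_size):
    node, local_rank = split_rank(rank, local_size)
    # By convention, global rank 0 acts as the main process.
    print(f"rank={rank} node={node} local_rank={local_rank} main={rank == 0}")
```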

### Communication Functions

Functions for inter-process communication in distributed training.

```python { .api }
def all_reduce(tensor, op: str = 'sum', group=None, async_op: bool = False):
    """
    All-reduce operation across processes.

    Parameters:
    - tensor: Tensor to reduce
    - op: Reduction operation ('sum', 'mean', 'max', 'min')
    - group: Process group
    - async_op: Whether to perform asynchronously
    """

def all_gather(tensor_list: list, tensor, group=None, async_op: bool = False):
    """
    All-gather operation across processes.

    Parameters:
    - tensor_list: List to store gathered tensors
    - tensor: Tensor to gather
    - group: Process group
    - async_op: Whether to perform asynchronously
    """

def all_gather_object(object_list: list, obj, group=None):
    """
    All-gather Python objects across processes.

    Parameters:
    - object_list: List to store gathered objects
    - obj: Object to gather
    - group: Process group
    """

def broadcast(tensor, src: int = 0, group=None, async_op: bool = False):
    """
    Broadcast tensor from source process.

    Parameters:
    - tensor: Tensor to broadcast
    - src: Source process rank
    - group: Process group
    - async_op: Whether to perform asynchronously
    """

def broadcast_object_list(object_list: list, src: int = 0, group=None):
    """
    Broadcast list of objects from source process.

    Parameters:
    - object_list: List of objects to broadcast
    - src: Source process rank
    - group: Process group
    """

def gather(tensor, gather_list: list = None, dst: int = 0, group=None, async_op: bool = False):
    """
    Gather tensors to destination process.

    Parameters:
    - tensor: Tensor to gather
    - gather_list: List to store gathered tensors
    - dst: Destination process rank
    - group: Process group
    - async_op: Whether to perform asynchronously
    """

def gather_object(obj, object_gather_list: list = None, dst: int = 0, group=None):
    """
    Gather Python objects to destination process.

    Parameters:
    - obj: Object to gather
    - object_gather_list: List to store gathered objects
    - dst: Destination process rank
    - group: Process group
    """

def reduce(tensor, dst: int = 0, op: str = 'sum', group=None, async_op: bool = False):
    """
    Reduce tensor to destination process.

    Parameters:
    - tensor: Tensor to reduce
    - dst: Destination process rank
    - op: Reduction operation
    - group: Process group
    - async_op: Whether to perform asynchronously
    """

def barrier(group=None, async_op: bool = False):
    """
    Synchronization barrier across processes.

    Parameters:
    - group: Process group
    - async_op: Whether to perform asynchronously
    """

def sync_random_seed(seed: int = None, device: str = 'cuda') -> int:
    """
    Synchronize random seed across processes.

    Parameters:
    - seed: Random seed (generated if None)
    - device: Device for synchronization

    Returns:
    Synchronized seed
    """
```
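
To make the semantics of these collectives concrete, here is a single-process simulation in plain Python, where each list element stands for the value held by one rank. The `simulate_*` helpers are hypothetical and only mimic what each rank would observe; real calls operate on tensors across processes.

```python
def simulate_all_reduce(values, op="sum"):
    """Return the value every rank would hold after an all-reduce."""
    reducers = {
        "sum": sum,
        "mean": lambda xs: sum(xs) / len(xs),
        "max": max,
        "min": min,
    }
    if op not in reducers:
        raise ValueError(f"unsupported op: {op}")
    reduced = reducers[op](values)
    # Every rank ends up with the same reduced value.
    return [reduced] * len(values)


def simulate_all_gather(values):
    """Every rank receives the full list of per-rank values."""
    return [list(values) for _ in values]


def simulate_broadcast(values, src=0):
    """Every rank receives the value held by rank `src`."""
    return [values[src]] * len(values)


per_rank_loss = [0.5, 1.5, 1.0, 3.0]  # one entry per simulated rank
print(simulate_all_reduce(per_rank_loss, op="mean"))
print(simulate_all_gather([10, 20]))
print(simulate_broadcast([7, 8, 9], src=1))
```

The contrast with `reduce` and `gather` is that those deliver the result only to the destination rank `dst`, whereas the `all_*` variants deliver it to every rank.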

### Advanced Communication

Higher-level communication functions for complex operations.

```python { .api }
def all_reduce_dict(py_dict: dict, op: str = 'mean', group=None, to_float: bool = True) -> dict:
    """
    All-reduce dictionary of tensors.

    Parameters:
    - py_dict: Dictionary of tensors
    - op: Reduction operation
    - group: Process group
    - to_float: Whether to convert to float

    Returns:
    Reduced dictionary
    """

def all_reduce_params(params, coalesce: bool = True, bucket_size_mb: int = -1):
    """
    All-reduce model parameters.

    Parameters:
    - params: Model parameters
    - coalesce: Whether to coalesce parameters
    - bucket_size_mb: Bucket size in MB
    """

def collect_results(result_part: list, size: int, tmpdir: str = None) -> list:
    """
    Collect results from all processes.

    Parameters:
    - result_part: Partial results from current process
    - size: Total size of dataset
    - tmpdir: Temporary directory for file-based collection

    Returns:
    Collected results from all processes
    """

def collect_results_cpu(result_part: list, size: int, tmpdir: str = None) -> list:
    """
    Collect results to CPU from all processes.

    Parameters:
    - result_part: Partial results
    - size: Total dataset size
    - tmpdir: Temporary directory

    Returns:
    CPU results from all processes
    """

def collect_results_gpu(result_part: list, size: int) -> list:
    """
    Collect results on GPU from all processes.

    Parameters:
    - result_part: Partial results
    - size: Total dataset size

    Returns:
    GPU results from all processes
    """
```
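
One way to picture result collection: each rank evaluates an interleaved shard of the dataset, and a shard may be padded with repeated samples so that all ranks process the same number of items. The sketch below merges such shards back into dataset order and trims the padding using `size`. It assumes round-robin sharding (as with an unshuffled PyTorch `DistributedSampler`); `merge_results` is a hypothetical helper and not necessarily how `collect_results` is implemented.

```python
from itertools import chain, zip_longest

_MISSING = object()  # sentinel for shards that are shorter than others


def merge_results(parts, size):
    """Interleave per-rank result lists and keep the first `size` items."""
    interleaved = chain.from_iterable(zip_longest(*parts, fillvalue=_MISSING))
    ordered = [item for item in interleaved if item is not _MISSING]
    # Anything past `size` is padding added to balance the shards.
    return ordered[:size]


# 5 samples on 2 ranks: rank 0 saw samples 0, 2, 4;
# rank 1 saw samples 1, 3 plus a padded repeat of sample 0.
parts = [["s0", "s2", "s4"], ["s1", "s3", "s0"]]
print(merge_results(parts, size=5))
```

This is why the `size` argument matters: without it, the padded duplicate would be counted twice in the final metrics.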

### Device Management

Functions for managing devices in distributed environments.

```python { .api }
def get_device() -> str:
    """
    Get current device.

    Returns:
    Device string ('cuda:0', 'cpu', etc.)
    """

def get_data_device(data) -> str:
    """
    Get device of data.

    Parameters:
    - data: Input data (tensor, dict, list, etc.)

    Returns:
    Device string
    """

def get_comm_device(group=None) -> str:
    """
    Get communication device for process group.

    Parameters:
    - group: Process group

    Returns:
    Communication device
    """

def cast_data_device(data, device: str, out=None):
    """
    Cast data to specified device.

    Parameters:
    - data: Input data
    - device: Target device
    - out: Output container

    Returns:
    Data on target device
    """
```
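
Because these helpers accept nested containers (tensor, dict, list, etc.), a recursive traversal is a natural way to think about them. The following is a minimal sketch of that idea using a hypothetical `map_nested` helper, with string tags standing in for real tensors; the actual implementation may differ.

```python
def map_nested(data, fn):
    """Apply `fn` to every leaf of a nested dict/list/tuple structure."""
    if isinstance(data, dict):
        return {key: map_nested(value, fn) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (plain list or tuple).
        return type(data)(map_nested(item, fn) for item in data)
    return fn(data)


batch = {"inputs": [1, 2], "meta": ("a", {"scale": 3})}
# Tag each leaf with a device name instead of actually moving a tensor.
print(map_nested(batch, lambda leaf: f"{leaf}@cuda:0"))
```

With real PyTorch tensors, the leaf function would typically be something like `lambda t: t.to(device)`.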

### Model Wrappers

Distributed data parallel wrappers for models.

```python { .api }
class MMDistributedDataParallel:
    def __init__(
        self,
        module,
        device_ids: list = None,
        output_device: int = None,
        broadcast_buffers: bool = True,
        find_unused_parameters: bool = False,
        bucket_cap_mb: int = 25,
        gradient_as_bucket_view: bool = False,
    ):
        """
        MMEngine's distributed data parallel wrapper.

        Parameters:
        - module: Model module to wrap
        - device_ids: Device IDs for this process
        - output_device: Output device ID
        - broadcast_buffers: Whether to broadcast buffers
        - find_unused_parameters: Whether to find unused parameters
        - bucket_cap_mb: Bucket capacity in MB
        - gradient_as_bucket_view: Whether to use gradient bucket view
        """

    def forward(self, *inputs, **kwargs):
        """
        Forward pass with gradient synchronization.

        Parameters:
        - *inputs: Input arguments
        - **kwargs: Input keyword arguments

        Returns:
        Model outputs
        """

class MMSeparateDistributedDataParallel:
    def __init__(
        self,
        module,
        device_ids: list = None,
        output_device: int = None,
        broadcast_buffers: bool = True,
        find_unused_parameters: bool = False,
    ):
        """
        Separate distributed data parallel for different parameter groups.

        Parameters:
        - module: Model module
        - device_ids: Device IDs
        - output_device: Output device
        - broadcast_buffers: Whether to broadcast buffers
        - find_unused_parameters: Whether to find unused parameters
        """

class MMFullyShardedDataParallel:
    def __init__(
        self,
        module,
        process_group=None,
        sharding_strategy=None,
        cpu_offload=None,
        auto_wrap_policy=None,
        backward_prefetch=None,
        mixed_precision=None,
        ignored_modules=None,
        param_init_fn=None,
        device_id: int = None,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        use_orig_params: bool = False,
    ):
        """
        Fully sharded data parallel wrapper (PyTorch >=2.0).

        Parameters:
        - module: Model module
        - process_group: Process group
        - sharding_strategy: Sharding strategy
        - cpu_offload: CPU offload configuration
        - auto_wrap_policy: Auto-wrap policy
        - backward_prefetch: Backward prefetch strategy
        - mixed_precision: Mixed precision policy
        - ignored_modules: Modules to ignore
        - param_init_fn: Parameter initialization function
        - device_id: Device ID
        - sync_module_states: Whether to sync module states
        - forward_prefetch: Whether to prefetch in forward
        - limit_all_gathers: Whether to limit all-gathers
        - use_orig_params: Whether to use original parameters
        """

def is_model_wrapper(model) -> bool:
    """
    Check if model is wrapped with distributed wrapper.

    Parameters:
    - model: Model to check

    Returns:
    True if model is wrapped
    """
```
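
A common pattern when saving checkpoints is to reach through the wrapper to the underlying model. The sketch below uses stand-in classes (`FakeWrapper`, `FakeModel`) and a hypothetical `unwrap` helper; it relies only on the convention, used by PyTorch's DDP, that the wrapped model is stored on `.module`.

```python
class FakeWrapper:
    """Stand-in for a DDP-style wrapper that stores the model on .module."""

    def __init__(self, module):
        self.module = module


class FakeModel:
    """Stand-in for a user model."""

    def state_dict(self):
        return {"weight": 1.0}


def unwrap(model, wrapper_types=(FakeWrapper,)):
    """Peel off wrapper layers until the bare model is reached."""
    while isinstance(model, wrapper_types):
        model = model.module
    return model


bare = FakeModel()
wrapped = FakeWrapper(bare)
# Both calls reach the same underlying state dict.
print(unwrap(wrapped).state_dict())
print(unwrap(bare).state_dict())
```

In real code, a check such as `is_model_wrapper(model)` would take the place of the `isinstance` test against the stand-in class.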

### Utility Decorators

Decorators for distributed training utilities.

```python { .api }
def master_only(func):
    """
    Decorator to run function only on master process.

    Parameters:
    - func: Function to decorate

    Returns:
    Decorated function
    """
```
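
A decorator of this kind can be written in a few lines. The sketch below is a stand-in built on stated assumptions, not MMEngine's implementation: it reads the rank from the `RANK` environment variable (exported by `torchrun`) and treats a missing variable as a single-process run. The name `rank_zero_only` is hypothetical.

```python
import functools
import os


def rank_zero_only(func):
    """Run `func` only when the current process has rank 0."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Assumption: RANK holds the global rank; default to 0 if unset.
        if int(os.environ.get("RANK", "0")) == 0:
            return func(*args, **kwargs)
        # Other ranks skip the call and get None back.
        return None

    return wrapper


@rank_zero_only
def log(message):
    return f"[rank 0] {message}"


print(log("checkpoint saved"))
```

Note that non-main ranks receive `None`, so a decorated function should not be relied on for a return value that every rank needs.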

## Usage Examples

### Basic Distributed Training Setup

```python
import torch
from mmengine.dist import get_dist_info, get_local_rank, init_dist
from mmengine.model import MMDistributedDataParallel
from mmengine.runner import Runner

# Initialize distributed training
init_dist('pytorch', backend='nccl')

# Get distributed info
rank, world_size = get_dist_info()
local_rank = get_local_rank()

# Set device
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)

# Create model and move to device (MyModel is a user-defined placeholder)
model = MyModel().to(device)

# Wrap with DDP
model = MMDistributedDataParallel(
    model,
    device_ids=[local_rank],
    broadcast_buffers=False,
    find_unused_parameters=False
)

# Create runner with distributed configuration
# (train_loader is a user-defined dataloader)
runner = Runner(
    model=model,
    work_dir='./work_dir',
    train_dataloader=train_loader,
    launcher='pytorch'
)

runner.train()
```

### Communication Examples

```python
import torch
from mmengine.dist import (all_gather, all_reduce, all_reduce_dict, broadcast,
                           get_rank, get_world_size)

# All-reduce operation
loss = torch.tensor(0.5).cuda()
all_reduce(loss, op='mean')  # Average loss across all processes

# All-gather operation
local_tensor = torch.randn(4).cuda()
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(get_world_size())]
all_gather(gathered_tensors, local_tensor)

# Broadcast operation
if get_rank() == 0:
    data = torch.randn(10).cuda()
else:
    data = torch.zeros(10).cuda()
broadcast(data, src=0)

# Dictionary all-reduce
metrics = {'loss': torch.tensor(0.5), 'acc': torch.tensor(0.9)}
reduced_metrics = all_reduce_dict(metrics, op='mean')
```

### Result Collection

```python
from mmengine.dist import collect_results, is_main_process

# Collect evaluation results from all processes
def evaluate_model(model, dataloader):
    results = []
    for batch in dataloader:
        outputs = model(batch)
        results.extend(outputs)

    # Collect results from all processes
    all_results = collect_results(results, len(dataloader.dataset))

    # Only compute metrics on main process
    # (compute_metrics is a user-defined placeholder)
    if is_main_process():
        metrics = compute_metrics(all_results)
        return metrics
    return {}
```

### Master-Only Operations

```python
import torch
from mmengine.dist import master_only, is_main_process

@master_only
def save_checkpoint(model, path):
    """Save checkpoint only on master process."""
    torch.save(model.state_dict(), path)

@master_only
def log_metrics(metrics):
    """Log metrics only on master process."""
    print(f"Metrics: {metrics}")

# Alternative approach
def training_step(model, data):
    loss = model(data)

    if is_main_process():
        print(f"Loss: {loss.item()}")

    return loss
```

### Advanced DDP Configuration

```python
from mmengine.model import (MMDistributedDataParallel,
                            MMSeparateDistributedDataParallel)

# `model` and `local_rank` are assumed to be defined as in the basic setup.

# DDP with gradient bucketing and unused parameter detection
ddp_model = MMDistributedDataParallel(
    model,
    device_ids=[local_rank],
    output_device=local_rank,
    broadcast_buffers=True,
    find_unused_parameters=True,
    bucket_cap_mb=25,
    gradient_as_bucket_view=True
)

# Separate DDP for models with different parameter update frequencies
# (an alternative to the wrapper above; wrap the bare model, not ddp_model)
separate_model = MMSeparateDistributedDataParallel(
    model,
    device_ids=[local_rank],
    find_unused_parameters=True
)
```

### FSDP Usage (PyTorch >=2.0)

```python
from mmengine.model import MMFullyShardedDataParallel
from torch.distributed.fsdp import ShardingStrategy, CPUOffload

# FSDP configuration
model = MMFullyShardedDataParallel(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=None,
    backward_prefetch=None,
    forward_prefetch=False,
    limit_all_gathers=True
)
```

### Random Seed Synchronization

```python
import numpy as np
import torch
from mmengine.dist import sync_random_seed

# Synchronize random seed across all processes
seed = sync_random_seed(42)

# Use synchronized seed
torch.manual_seed(seed)
np.random.seed(seed)
```