# Distributed Training

Comprehensive distributed computing support with multiple backends, including native PyTorch DDP, Horovod, and XLA/TPU. PyTorch Ignite provides a unified API for distributed training across different platforms and scales.

## Capabilities

### Initialization and Setup

Functions for initializing and configuring distributed training backends.

```python { .api }
def initialize(backend, **kwargs):
    """
    Initialize distributed backend.

    Parameters:
    - backend: backend type ('nccl', 'gloo', 'mpi', 'horovod', 'xla-tpu')
    - **kwargs: backend-specific arguments

    Supported kwargs:
    - For 'nccl'/'gloo'/'mpi': init_method, rank, world_size, timeout
    - For 'horovod': no additional arguments
    - For 'xla-tpu': no additional arguments
    """

def finalize():
    """Finalize distributed backend and clean up resources."""

def show_config():
    """Print the current distributed configuration (backend, rank, world size, device, etc.)."""
```
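
Since `finalize()` must run even when training fails, pairing the two calls in a context manager is a natural pattern. The sketch below is a hypothetical convenience wrapper, not part of the ignite API; stub callables stand in for `idist.initialize` / `idist.finalize` so it runs anywhere:

```python
from contextlib import contextmanager

# Hypothetical helper (not part of the ignite API): pairing initialize and
# finalize in a context manager guarantees cleanup even if training raises.
@contextmanager
def distributed(initialize, finalize, backend, **kwargs):
    initialize(backend, **kwargs)
    try:
        yield
    finally:
        finalize()

# Stub callables record the lifecycle instead of touching a real backend.
calls = []
with distributed(lambda b, **kw: calls.append(("init", b)),
                 lambda: calls.append(("finalize",)),
                 backend="gloo"):
    calls.append(("train",))
print(calls)  # [('init', 'gloo'), ('train',), ('finalize',)]
```

In real code the stubs would be `idist.initialize` and `idist.finalize`.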

### Communication Utilities

Core distributed communication primitives for data synchronization.

```python { .api }
def sync(temporary=False):
    """
    Synchronize the module's state with the current distributed context
    (e.g. when the context was created outside of Ignite).

    Parameters:
    - temporary: if True, the distributed configuration is not cached
    """

def barrier():
    """Synchronization barrier across all processes."""

def broadcast(tensor, src=0, safe_mode=False):
    """
    Broadcast tensor from the source rank to all processes.

    Parameters:
    - tensor: tensor to broadcast
    - src: source rank
    - safe_mode: if True, non-source ranks may pass None

    Returns:
    Broadcast tensor
    """

def all_reduce(tensor, op='SUM', group=None):
    """
    All-reduce operation across all processes.

    Parameters:
    - tensor: tensor to reduce
    - op: reduction operation ('SUM', 'PRODUCT', 'MIN', 'MAX')
    - group: process group (optional)

    Returns:
    Reduced tensor
    """

def all_gather(tensor, group=None):
    """
    All-gather operation across all processes.

    Parameters:
    - tensor: tensor to gather
    - group: process group (optional)

    Returns:
    Tensor with values from all processes concatenated along the first
    dimension (a list is returned for non-tensor inputs)
    """
```
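
To make the collective semantics concrete, here is a backend-free sketch that simulates per-rank values with a plain list. The function names are illustrative, not part of the ignite API:

```python
from math import prod

# Backend-free sketch of all_reduce / all_gather semantics: a list element
# per simulated rank stands in for each process's local tensor.
def simulated_all_reduce(per_rank_values, op="SUM"):
    """After all_reduce, every rank holds the same reduced value."""
    ops = {"SUM": sum, "PRODUCT": prod, "MIN": min, "MAX": max}
    reduced = ops[op](per_rank_values)
    return [reduced for _ in per_rank_values]  # one copy per rank

def simulated_all_gather(per_rank_values):
    """After all_gather, every rank holds all values, ordered by rank."""
    return [list(per_rank_values) for _ in per_rank_values]

losses = [1.0, 2.0, 3.0, 4.0]          # e.g. local loss on ranks 0..3
print(simulated_all_reduce(losses))     # every rank sees the sum 10.0
print(simulated_all_gather(losses)[0])  # rank 0 sees [1.0, 2.0, 3.0, 4.0]
```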

### Information Queries

Functions for querying the distributed environment.

```python { .api }
def backend():
    """
    Get current distributed backend name.

    Returns:
    String name of the current backend ('nccl', 'gloo', 'horovod', 'xla-tpu'),
    or None if no distributed configuration is set up
    """

def available_backends():
    """
    Get the available distributed backends.

    Returns:
    Tuple of available backend names
    """

def model_name():
    """
    Get the name of the distributed computation model.

    Returns:
    One of 'serial', 'native-dist', 'horovod-dist', 'xla-dist'
    """

def get_rank():
    """
    Get current process rank.

    Returns:
    Integer rank of the current process (0 for a single process)
    """

def get_local_rank():
    """
    Get local process rank within the node.

    Returns:
    Integer local rank of the current process
    """

def get_world_size():
    """
    Get total number of processes.

    Returns:
    Integer total number of processes (1 for a single process)
    """
```
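
When processes are launched with a native launcher such as `torchrun`, rank and world size are conventionally exposed through the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables. The sketch below illustrates that lookup with safe single-process defaults; it is illustrative, not ignite's actual implementation:

```python
import os

# Illustrative sketch: read distributed identity from launcher-provided
# environment variables, defaulting to a single serial process.
def rank_from_env(environ=os.environ):
    """Return (rank, local_rank, world_size) for the current process."""
    return (
        int(environ.get("RANK", 0)),
        int(environ.get("LOCAL_RANK", 0)),
        int(environ.get("WORLD_SIZE", 1)),
    )

print(rank_from_env({"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "8"}))
# (3, 1, 8)
```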

### Process Management

Functions for spawning and managing distributed processes.

```python { .api }
def spawn(backend, fn, args=(), kwargs_dict=None, nprocs=1, **kwargs):
    """
    Spawn nprocs processes that run fn and set up the distributed
    configuration defined by backend.

    Parameters:
    - backend: distributed backend ('nccl', 'gloo', 'horovod', 'xla-tpu')
    - fn: function to run in each process, called as fn(local_rank, *args, **kwargs_dict)
    - args: positional arguments passed to fn
    - kwargs_dict: keyword arguments passed to fn
    - nprocs: number of processes to spawn
    - **kwargs: backend- and spawn-specific options
    """

class Parallel:
    """
    Parallel execution launcher for distributed training.

    Parameters:
    - backend: distributed backend to use
    - nproc_per_node: number of processes per node
    - **kwargs: additional backend and spawn arguments
    """
    def __init__(self, backend=None, nproc_per_node=None, **kwargs): ...

    def run(self, fn, *args, **kwargs):
        """Run fn as fn(local_rank, *args, **kwargs) across all processes."""
```
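
A useful property of this launcher design is the serial fallback: with no backend and a single process, the function is simply invoked in the current process with `local_rank` 0, so the same training script runs unchanged on one machine. A minimal pure-Python sketch of that dispatch (class and names are hypothetical, not the ignite implementation):

```python
# Hypothetical sketch of a Parallel-style launcher's dispatch: a real
# launcher would spawn worker processes, while the serial fallback simply
# calls fn once per simulated local rank in the current process.
class SerialLauncher:
    def __init__(self, nprocs=None):
        self.nprocs = nprocs or 1  # default: one serial process

    def run(self, fn, *args, **kwargs):
        # fn receives the local rank first, mirroring fn(local_rank, *args).
        return [fn(local_rank, *args, **kwargs) for local_rank in range(self.nprocs)]

results = SerialLauncher().run(lambda rank, base: base + rank, 100)
print(results)  # [100]
```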

### Data Handling

Auto-configuration utilities for distributed data loading and model setup.

```python { .api }
def auto_dataloader(dataset, **kwargs):
    """
    Create a DataLoader adapted to the current distributed configuration
    (e.g. with a DistributedSampler and adjusted batch size).

    Parameters:
    - dataset: dataset to load
    - **kwargs: keyword arguments accepted by torch DataLoader (batch_size, num_workers, ...)

    Returns:
    DataLoader configured for distributed use
    """

def auto_model(model, sync_bn=False, **kwargs):
    """
    Auto-configure model for distributed training (moves it to the current
    device and wraps it, e.g. in DistributedDataParallel).

    Parameters:
    - model: PyTorch model
    - sync_bn: whether to convert batch normalization layers to synchronized ones
    - **kwargs: additional arguments for DistributedDataParallel

    Returns:
    Wrapped model for distributed training
    """

def auto_optim(optimizer, **kwargs):
    """
    Auto-configure optimizer for distributed training.

    Parameters:
    - optimizer: PyTorch optimizer
    - **kwargs: additional arguments

    Returns:
    Configured optimizer (e.g. a Horovod DistributedOptimizer; unchanged
    for native DDP)
    """
```
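
The key effect of `auto_dataloader` is that each rank iterates over a disjoint slice of the dataset. The sketch below shows the round-robin index assignment a DistributedSampler-style partition produces (the real sampler additionally shuffles per epoch and pads so all ranks see the same number of samples):

```python
# Illustrative sketch of DistributedSampler-style index partitioning:
# rank r processes every world_size-th sample starting at r.
def partition_indices(num_samples, rank, world_size):
    """Return the dataset indices assigned to this rank."""
    return list(range(rank, num_samples, world_size))

# 10 samples over 4 ranks: disjoint slices that together cover the data.
for rank in range(4):
    print(rank, partition_indices(10, rank, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```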

### Capability Detection

Functions for checking distributed backend capabilities.

```python { .api }
def has_native_dist_support():
    """
    Check if native PyTorch distributed support is available.

    Returns:
    Boolean indicating availability
    """

def has_hvd_support():
    """
    Check if Horovod support is available.

    Returns:
    Boolean indicating availability
    """

def has_xla_support():
    """
    Check if XLA/TPU support is available.

    Returns:
    Boolean indicating availability
    """
```
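
Checks like these typically reduce to "can the backing package be imported?". A minimal sketch of that pattern using the standard library (`has_module` is a hypothetical helper, not part of the ignite API):

```python
import importlib.util

# Illustrative capability check: probe for a package without importing it.
def has_module(name):
    """Return True if a top-level module can be found on this system."""
    return importlib.util.find_spec(name) is not None

print(has_module("json"))     # True: stdlib is always present
print(has_module("horovod"))  # depends on the environment
```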

### Utilities

Convenience utilities for distributed training workflows.

```python { .api }
def one_rank_only(rank=0, with_barrier=False):
    """
    Decorator to execute a function on a single rank only.

    Parameters:
    - rank: rank to execute on (default: 0)
    - with_barrier: whether to synchronize all processes with a barrier
      after execution (default: False)

    Returns:
    Decorator function
    """
```
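
A backend-free sketch of how a `one_rank_only`-style decorator works: the wrapped function only executes when the current rank matches. Here the current rank is passed in explicitly, whereas ignite reads it from the distributed context; the names are illustrative:

```python
import functools

# Illustrative decorator sketch: run fn only on the chosen rank; other
# ranks skip the call and get None back.
def one_rank_only_sketch(rank=0, current_rank=0):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if current_rank == rank:
                return fn(*args, **kwargs)
            return None  # other ranks skip the call
        return wrapper
    return decorator

@one_rank_only_sketch(rank=0, current_rank=0)
def save_message():
    return "saved"

print(save_message())  # "saved" on rank 0, None elsewhere
```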

## Usage Examples

### Basic Distributed Training Setup

```python
import ignite.distributed as idist
from ignite.engine import create_supervised_trainer

# Initialize distributed backend
idist.initialize("nccl")  # or "gloo", "horovod", "xla-tpu"

# Auto-configure model, optimizer, and dataloader
model = idist.auto_model(model)
optimizer = idist.auto_optim(optimizer)
train_loader = idist.auto_dataloader(train_dataset, batch_size=32)

# Create trainer
trainer = create_supervised_trainer(model, optimizer, criterion)

# Run training
trainer.run(train_loader, max_epochs=100)

# Finalize
idist.finalize()
```

### Multi-GPU Training with DDP

```python
import torch
import ignite.distributed as idist
from ignite.engine import Events, create_supervised_trainer

def training(local_rank, config):
    # The backend is initialized and finalized by idist.spawn

    # Setup model and data (create_model, train_dataset, criterion are
    # user-defined)
    model = create_model()
    model = idist.auto_model(model)

    train_loader = idist.auto_dataloader(train_dataset, batch_size=config["batch_size"])

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    optimizer = idist.auto_optim(optimizer)

    # Create trainer
    trainer = create_supervised_trainer(model, optimizer, criterion)

    # Add logging only on rank 0
    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    @idist.one_rank_only()
    def log_training(engine):
        print(f"Rank {idist.get_rank()}: Iteration {engine.state.iteration}, Loss: {engine.state.output}")

    # Run training
    trainer.run(train_loader, max_epochs=10)

if __name__ == "__main__":
    config = {"batch_size": 32}
    nprocs = torch.cuda.device_count()
    idist.spawn("nccl", training, args=(config,), nprocs=nprocs)
```

### Horovod Training

```python
import ignite.distributed as idist

# Initialize Horovod backend
idist.initialize("horovod")

# Auto-configure components: auto_model broadcasts the initial parameters
# from rank 0, and auto_optim wraps the optimizer in Horovod's
# DistributedOptimizer
model = idist.auto_model(model)
optimizer = idist.auto_optim(optimizer)
train_loader = idist.auto_dataloader(train_dataset, batch_size=32)

# Create trainer
trainer = create_supervised_trainer(model, optimizer, criterion)

# Run training
trainer.run(train_loader, max_epochs=100)

# Finalize
idist.finalize()
```

### XLA/TPU Training

```python
import ignite.distributed as idist

# Initialize XLA backend
idist.initialize("xla-tpu")

# Auto-configure for TPU
model = idist.auto_model(model)
optimizer = idist.auto_optim(optimizer)
train_loader = idist.auto_dataloader(train_dataset, batch_size=32)

# Create trainer
trainer = create_supervised_trainer(model, optimizer, criterion)

# Run training
trainer.run(train_loader, max_epochs=100)

# Finalize
idist.finalize()
```

### Custom Communication

```python
import ignite.distributed as idist
import torch

# Check if distributed
if idist.get_world_size() > 1:
    # Broadcast model parameters from rank 0
    for param in model.parameters():
        idist.broadcast(param.data, src=0)

    # All-reduce gradients and average them
    for param in model.parameters():
        if param.grad is not None:
            idist.all_reduce(param.grad.data)
            param.grad.data /= idist.get_world_size()

# Gather loss from all processes; for a scalar input, all_gather returns
# a tensor of shape (world_size,)
local_loss = torch.tensor(loss_value)
all_losses = idist.all_gather(local_loss)
avg_loss = all_losses.mean()
```

### Process Management with Spawn

```python
import ignite.distributed as idist

def train_worker(local_rank, world_size, config):
    # Worker function for each process; the backend is initialized and
    # finalized by idist.spawn
    print(f"Worker {local_rank} of {world_size} started")

    # Training code here
    # ...

# Spawn workers; each is called as train_worker(local_rank, 4, config)
idist.spawn(
    "gloo",
    train_worker,
    args=(4, config),  # world_size, config
    nprocs=4,
)
```

### Conditional Execution

```python
import torch
import ignite.distributed as idist

# Execute only on rank 0
@idist.one_rank_only(rank=0)
def save_checkpoint():
    torch.save(model.state_dict(), 'checkpoint.pth')

# Execute on rank 0, then synchronize all processes with a barrier
@idist.one_rank_only(rank=0, with_barrier=True)
def log_metrics():
    print(f"Epoch completed, metrics: {metrics}")

# Manual rank checking
if idist.get_rank() == 0:
    print("This runs only on the master process")
```