# Distributed Training Strategies

Multiple strategies for distributed and parallel training including data parallel, distributed data parallel, fully sharded data parallel, model parallel, and specialized strategies for different hardware configurations.

## Capabilities

### Distributed Data Parallel (DDP)

Multi-GPU and multi-node distributed training strategy that replicates the model across devices and synchronizes gradients.

```python { .api }
class DDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        start_method: str = "popen",
        **kwargs
    ):
        """
        Initialize DDP strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin for mixed precision
            ddp_comm_state: DDP communication state
            ddp_comm_hook: Custom communication hook
            ddp_comm_wrapper: Communication wrapper
            model_averaging_period: Period for model averaging
            process_group_backend: Backend for process group ('nccl', 'gloo')
            timeout: Timeout for distributed operations
            start_method: Method to start processes
        """
```

### Fully Sharded Data Parallel (FSDP)

Memory-efficient distributed training that shards model parameters, gradients, and optimizer states across devices.

```python { .api }
class FSDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        auto_wrap_policy: Optional[Callable] = None,
        cpu_offload: Union[bool, CPUOffload] = False,
        mixed_precision: Optional[MixedPrecision] = None,
        sharding_strategy: Union[ShardingStrategy, str] = "FULL_SHARD",
        backward_prefetch: Optional[BackwardPrefetch] = None,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        use_orig_params: bool = True,
        param_init_fn: Optional[Callable] = None,
        sync_module_states: bool = False,
        **kwargs
    ):
        """
        Initialize FSDP strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            process_group_backend: Backend for process group
            timeout: Timeout for distributed operations
            auto_wrap_policy: Policy for automatic module wrapping
            cpu_offload: Enable CPU offloading of parameters
            mixed_precision: Mixed precision configuration
            sharding_strategy: Strategy for parameter sharding
            backward_prefetch: Prefetch strategy for backward pass
            forward_prefetch: Enable forward prefetching
            limit_all_gathers: Limit all-gather operations
            use_orig_params: Use original parameters
            param_init_fn: Parameter initialization function
            sync_module_states: Synchronize module states
        """
```

### DeepSpeed Integration

Integration with Microsoft DeepSpeed for memory-efficient training of large models with advanced optimization techniques.

```python { .api }
class DeepSpeedStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        zero_optimization: bool = True,
        stage: int = 2,
        remote_device: Optional[str] = None,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: str = "/local_nvme",
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        pin_memory: bool = False,
        sub_group_size: int = 1_000_000_000_000,
        cpu_checkpointing: bool = False,
        contiguous_gradients: bool = True,
        overlap_comm: bool = True,
        allgather_partitions: bool = True,
        reduce_scatter: bool = True,
        allgather_bucket_size: int = 200_000_000,
        reduce_bucket_size: int = 200_000_000,
        zero_allow_untested_optimizer: bool = True,
        logging_batch_size_per_gpu: str = "auto",
        config: Optional[Union[Path, str, Dict]] = None,
        logging_level: int = logging.WARN,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize DeepSpeed strategy.

        Args:
            accelerator: Hardware accelerator to use
            zero_optimization: Enable ZeRO optimization
            stage: ZeRO stage (1, 2, or 3)
            remote_device: Remote device for parameter storage
            offload_optimizer: Offload optimizer to CPU
            offload_parameters: Offload parameters to CPU
            offload_params_device: Device for parameter offloading
            nvme_path: Path for NVMe offloading
            params_buffer_count: Number of parameter buffers
            params_buffer_size: Size of parameter buffers
            max_in_cpu: Maximum parameters in CPU memory
            offload_optimizer_device: Device for optimizer offloading
            config: DeepSpeed configuration file or dictionary
            logging_level: Logging level for DeepSpeed
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            process_group_backend: Backend for process group
        """
```

### Data Parallel Strategy

Simple data parallelism that replicates the model on multiple devices and averages gradients.

```python { .api }
class DataParallelStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize DataParallel strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```

### Single Device Strategy

Strategy for training on a single device (CPU or GPU).

```python { .api }
class SingleDeviceStrategy:
    def __init__(
        self,
        device: torch.device,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize single device strategy.

        Args:
            device: Device to use for training
            accelerator: Hardware accelerator to use
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```

### XLA Strategies

Strategies for Google TPU training using XLA compilation.

```python { .api }
class XLAStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        sync_module_states: bool = True
    ):
        """
        Initialize XLA strategy for multi-TPU training.

        Args:
            accelerator: XLA accelerator
            parallel_devices: List of TPU devices
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            debug: Enable debug mode
            sync_module_states: Synchronize module states
        """

class SingleDeviceXLAStrategy:
    def __init__(
        self,
        device: torch.device,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False
    ):
        """
        Initialize single TPU device strategy.

        Args:
            device: TPU device to use
            accelerator: XLA accelerator
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            debug: Enable debug mode
        """

class XLAFSDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        **kwargs
    ):
        """
        Initialize XLA FSDP strategy combining XLA with fully sharded data parallel.

        Args:
            accelerator: XLA accelerator
            parallel_devices: List of TPU devices
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```
284
285
### Model Parallel Strategy
286
287
Strategy for model parallelism where different parts of the model are placed on different devices.
288
289
```python { .api }
class ModelParallelStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize model parallel strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for model placement
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```
308
309
### Base Strategy Classes
310
311
Base classes for creating custom training strategies.
312
313
```python { .api }
class Strategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Base strategy class.

        Args:
            accelerator: Hardware accelerator
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

    def setup_environment(self) -> None:
        """Set up the training environment."""

    def setup(self, trainer: Trainer) -> None:
        """Set up the strategy with trainer."""

    def teardown(self) -> None:
        """Clean up the strategy."""

class ParallelStrategy(Strategy):
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Base parallel strategy class.

        Args:
            accelerator: Hardware accelerator
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

    @property
    def global_rank(self) -> int:
        """Global rank of the current process."""

    @property
    def local_rank(self) -> int:
        """Local rank of the current process."""

    @property
    def world_size(self) -> int:
        """Total number of processes."""

    def all_gather(self, tensor: torch.Tensor, sync_grads: bool = False) -> torch.Tensor:
        """Gather tensor from all processes."""

    def all_reduce(self, tensor: torch.Tensor, reduce_op: str = "mean") -> torch.Tensor:
        """Reduce tensor across all processes."""

    def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
        """Broadcast tensor from source to all processes."""

    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes."""
```

## Usage Examples

### Basic Strategy Usage

```python
from lightning import Trainer

# Use DDP strategy
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp"
)

# Use FSDP strategy
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy="fsdp"
)
```

### Advanced Strategy Configuration

```python
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
from datetime import timedelta

# Configure DDP with custom settings
ddp_strategy = DDPStrategy(
    process_group_backend="nccl",
    timeout=timedelta(seconds=1800),
    start_method="spawn"
)

trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=ddp_strategy,
    precision="16-mixed"
)

# Configure FSDP with CPU offloading
from torch.distributed.fsdp import CPUOffload, ShardingStrategy

fsdp_strategy = FSDPStrategy(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=None,  # Let Lightning handle precision
    auto_wrap_policy=None  # Use default wrapping
)

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=fsdp_strategy,
    precision="bf16-mixed"
)
```

### DeepSpeed Configuration

```python
from lightning import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

# DeepSpeed ZeRO Stage 3 with offloading
deepspeed_strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,
    offload_parameters=True,
    remote_device="nvme",
    nvme_path="/local_nvme"
)

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=deepspeed_strategy,
    precision="16-mixed"
)

# DeepSpeed with custom config file
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=DeepSpeedStrategy(config="deepspeed_config.json"),
    precision="16-mixed"
)
```