# Strategies

Distributed training strategies that define how models and data are distributed across devices and processes.

## Capabilities

### Base Strategy

Abstract base class defining the strategy interface for distributed training.

```python { .api }
class Strategy:
    """
    Abstract base class for distributed training strategies.

    Strategies define how models, optimizers, and data are distributed
    across devices and processes for parallel training.
    """

    def setup_environment(self) -> None:
        """Set up the distributed training environment."""

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Set up module for distributed training."""

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Set up optimizer for distributed training."""

    def module_to_device(self, module: nn.Module) -> None:
        """Move module to appropriate device(s)."""

    def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Tensor:
        """Reduce tensor across processes."""

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """All-gather tensor across processes."""

    def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:
        """Broadcast tensor from source process."""

    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes."""

    def teardown(self) -> None:
        """Clean up strategy resources."""
```
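
The collective operations in this interface (`reduce`, `all_gather`, `broadcast`) can be hard to picture without a running cluster. The following is a minimal single-process sketch of what each one is expected to leave on every rank. The `simulate_*` helpers are hypothetical names invented for this illustration and are not part of the library; a real strategy would perform these operations with tensors and a communication backend.

```python
from typing import List


def simulate_reduce(values: List[float], op: str = "mean") -> List[float]:
    """Return what each rank holds after a reduce: the same combined value."""
    if op == "sum":
        result = sum(values)
    elif op == "mean":
        result = sum(values) / len(values)
    else:
        raise ValueError(f"unsupported reduce op: {op}")
    return [result] * len(values)


def simulate_all_gather(values: List[float]) -> List[List[float]]:
    """Return what each rank holds after an all-gather: every rank's value."""
    return [list(values) for _ in values]


def simulate_broadcast(values: List[float], src: int = 0) -> List[float]:
    """Return what each rank holds after a broadcast: the value from rank `src`."""
    return [values[src]] * len(values)


# One scalar per rank, e.g. a per-process loss value (4 simulated ranks).
per_rank_loss = [1.0, 2.0, 3.0, 6.0]

print(simulate_reduce(per_rank_loss, op="mean"))  # [3.0, 3.0, 3.0, 3.0]
print(simulate_all_gather(per_rank_loss)[0])      # [1.0, 2.0, 3.0, 6.0]
print(simulate_broadcast(per_rank_loss, src=2))   # [3.0, 3.0, 3.0, 3.0]
```

The index of each list entry stands in for a process rank, so the output shows the state of every rank after the collective completes.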

### Single Device Strategy

Strategy for training on a single device (CPU or GPU).

```python { .api }
class SingleDeviceStrategy(Strategy):
    """
    Strategy for single device training.

    Handles training on a single CPU or GPU without distributed communication.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """
        Initialize single device strategy.

        Args:
            device: Target device for training
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Move module to target device."""

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Return optimizer as-is (no distribution needed)."""
```

### Data Parallel Strategy

PyTorch DataParallel strategy for single-node multi-GPU training.

```python { .api }
class DataParallelStrategy(Strategy):
    """
    DataParallel strategy for single-node multi-GPU training.

    Uses PyTorch's DataParallel for simple multi-GPU training on a single node.
    Limited scalability compared to DistributedDataParallel.
    """

    def __init__(self, parallel_devices: Optional[list[torch.device]] = None):
        """
        Initialize DataParallel strategy.

        Args:
            parallel_devices: List of devices to use for parallel training
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with DataParallel."""

    def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """Reduce tensor across DataParallel devices."""
```
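
DataParallel-style training splits each input batch into contiguous chunks, runs one model replica per chunk, and gathers the outputs back in order. The sketch below imitates that scatter/gather flow with plain Python lists in a single process. The `scatter` and `data_parallel_forward` helpers are hypothetical and only illustrate the data movement, not the actual PyTorch implementation.

```python
import math
from typing import Callable, List, Sequence


def scatter(batch: Sequence[float], num_devices: int) -> List[List[float]]:
    """Split a batch into contiguous chunks of size ceil(len(batch) / num_devices)."""
    if len(batch) == 0:
        return []
    chunk_size = math.ceil(len(batch) / num_devices)
    return [list(batch[i:i + chunk_size]) for i in range(0, len(batch), chunk_size)]


def data_parallel_forward(
    forward: Callable[[float], float],
    batch: Sequence[float],
    num_devices: int,
) -> List[float]:
    """Scatter the batch, apply `forward` per chunk (one replica each), gather results."""
    chunks = scatter(batch, num_devices)
    per_device_outputs = [[forward(x) for x in chunk] for chunk in chunks]
    # Gather: concatenate per-device outputs in device order.
    return [y for outputs in per_device_outputs for y in outputs]


print(scatter(list(range(10)), num_devices=4))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(data_parallel_forward(lambda x: x * 2, [1, 2, 3, 4, 5], num_devices=2))
# [2, 4, 6, 8, 10]
```

Note that the last chunk can be smaller than the others when the batch size is not divisible by the number of devices.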

### Distributed Data Parallel Strategy

PyTorch DistributedDataParallel strategy for scalable multi-GPU training.

```python { .api }
class DDPStrategy(Strategy):
    """
    DistributedDataParallel strategy for scalable multi-GPU training.

    Uses PyTorch's DDP for efficient distributed training across
    multiple GPUs and nodes with gradient synchronization.
    """

    def __init__(
        self,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs
    ):
        """
        Initialize DDP strategy.

        Args:
            parallel_devices: Devices for parallel training
            cluster_environment: Cluster environment plugin
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            ddp_comm_state: DDP communication state
            ddp_comm_hook: Custom DDP communication hook
            ddp_comm_wrapper: DDP communication wrapper
            model_averaging_period: Period for model averaging
            process_group_backend: Process group backend (nccl, gloo, mpi)
            timeout: Timeout for distributed operations
        """

    def setup_distributed(self) -> None:
        """Initialize distributed process group."""

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with DistributedDataParallel."""

    def configure_ddp(self) -> None:
        """Configure DDP-specific settings."""
```
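
The "gradient synchronization" mentioned in the docstring amounts to averaging the per-process gradients before each optimizer step. A small single-process sketch, assuming a one-parameter linear model `y = w * x` with a mean-squared-error loss and equally sized data shards, shows why this gives the same update as training on the full batch. All names here (`grad_mse`, `ddp_step`) are hypothetical and chosen only for illustration.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair


def grad_mse(w: float, data: List[Sample]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)


def ddp_step(w: float, shards: List[List[Sample]], lr: float) -> float:
    """One SGD step where every rank computes a local gradient on its own shard."""
    local_grads = [grad_mse(w, shard) for shard in shards]
    # Stand-in for the all-reduce (mean) that synchronizes gradients across ranks.
    averaged = sum(local_grads) / len(local_grads)
    return w - lr * averaged


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
shards = [data[0:2], data[2:4]]  # two simulated ranks, equal shard sizes

w = 0.0
print(grad_mse(w, data))                      # -30.0 (full-batch gradient)
print([grad_mse(w, s) for s in shards])       # [-10.0, -50.0] (per-rank gradients)
print(ddp_step(w, shards, lr=0.01))           # close to 0.3, same as a full-batch step
```

The equivalence with the full-batch gradient holds exactly only when all shards have the same number of samples.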

### DeepSpeed Strategy

Microsoft DeepSpeed integration for large-scale model training.

```python { .api }
class DeepSpeedStrategy(Strategy):
    """
    DeepSpeed strategy for large-scale model training.

    Integrates with Microsoft DeepSpeed for memory-efficient training
    of large models using ZeRO partitioning of optimizer states,
    gradients, and (at stage 3) parameters.
    """

    def __init__(
        self,
        stage: int = 2,
        remote_device: Optional[str] = None,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: Optional[str] = None,
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        config: Optional[Union[str, dict]] = None,
        logging_level: int = logging.WARN,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs
    ):
        """
        Initialize DeepSpeed strategy.

        Args:
            stage: DeepSpeed ZeRO stage (1, 2, or 3)
            remote_device: Remote device for offloading
            offload_optimizer: Whether to offload optimizer states
            offload_parameters: Whether to offload parameters
            offload_params_device: Device for parameter offloading
            nvme_path: Path to NVMe storage for offloading
            config: DeepSpeed configuration dict or path to config file
            Other args: Additional DeepSpeed configuration options
        """

    def setup_module_and_optimizers(
        self,
        module: nn.Module,
        optimizers: list[Optimizer]
    ) -> tuple[nn.Module, list[Optimizer]]:
        """Set up module and optimizers with DeepSpeed engine."""

    def configure_deepspeed_config(self, config: dict) -> dict:
        """Configure DeepSpeed configuration dictionary."""
```
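
To get a feel for what the `stage` argument buys, the sketch below estimates per-GPU memory for model states under each ZeRO stage. It assumes the commonly cited accounting for mixed-precision Adam (2 bytes of fp16 weights, 2 bytes of fp16 gradients, and 12 bytes of fp32 optimizer state per parameter) and ignores activations, buffers, and fragmentation. The helper `zero_memory_gb` is a hypothetical name for this back-of-the-envelope calculation, not a DeepSpeed or Lightning API.

```python
def zero_memory_gb(
    num_params: float,
    num_gpus: int,
    stage: int,
    optimizer_bytes_per_param: int = 12,
) -> float:
    """Approximate per-GPU model-state memory in GB (1 GB = 1e9 bytes)."""
    if stage not in (0, 1, 2, 3):
        raise ValueError(f"stage must be 0, 1, 2, or 3, got {stage}")

    params = 2.0 * num_params                       # fp16 weights
    grads = 2.0 * num_params                        # fp16 gradients
    optim = optimizer_bytes_per_param * num_params  # fp32 optimizer states

    if stage >= 1:  # stage 1 partitions optimizer states
        optim /= num_gpus
    if stage >= 2:  # stage 2 additionally partitions gradients
        grads /= num_gpus
    if stage >= 3:  # stage 3 additionally partitions parameters
        params /= num_gpus

    return (params + grads + optim) / 1e9


# Example: a 7.5B-parameter model on 64 GPUs.
for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(7.5e9, 64, stage):.1f} GB per GPU")
# stage 0: 120.0 GB per GPU
# stage 1: 31.4 GB per GPU
# stage 2: 16.6 GB per GPU
# stage 3: 1.9 GB per GPU
```

Stage 0 here denotes plain data parallelism without any partitioning and is included only as a baseline.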

### FSDP Strategy

Fully Sharded Data Parallel strategy for memory-efficient large model training.

```python { .api }
class FSDPStrategy(Strategy):
    """
    Fully Sharded Data Parallel strategy for large model training.

    Uses PyTorch's FSDP to shard model parameters, gradients, and
    optimizer states across devices for memory-efficient training.
    """

    def __init__(
        self,
        cpu_offload: Optional[bool] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        auto_wrap_policy: Optional[Callable] = None,
        activation_checkpointing: Optional[bool] = None,
        activation_checkpointing_policy: Optional[Callable] = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        state_dict_type: Optional[StateDictType] = None,
        use_orig_params: bool = False,
        limit_all_gathers: bool = True,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        **kwargs
    ):
        """
        Initialize FSDP strategy.

        Args:
            cpu_offload: Whether to offload parameters and gradients to CPU
            mixed_precision: Mixed precision configuration
            auto_wrap_policy: Policy for automatic module wrapping
            activation_checkpointing: Whether to use activation checkpointing
            sharding_strategy: Parameter sharding strategy
            state_dict_type: Type of state dict for checkpointing
            use_orig_params: Whether to expose the original parameters
                instead of FSDP's flattened parameters
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Wrap module with FSDP."""

    def configure_fsdp_auto_wrap_policy(self, module: nn.Module) -> Optional[Callable]:
        """Configure automatic wrapping policy for FSDP."""
```
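
Sharding parameters across devices can be illustrated without any GPU: flatten the parameters into one list, pad it so it divides evenly, give each rank one contiguous slice, and all-gather the slices when the full parameters are needed. The sketch below does this with plain Python lists. The helper names `shard_flat_params` and `all_gather_params` are hypothetical, and the zero-padding scheme is an assumption made for the illustration rather than a description of PyTorch's internals.

```python
from typing import List


def shard_flat_params(flat_params: List[float], world_size: int) -> List[List[float]]:
    """Pad the flat parameter list and split it into one equal shard per rank."""
    shard_size = -(-len(flat_params) // world_size)  # ceiling division
    padding = shard_size * world_size - len(flat_params)
    padded = list(flat_params) + [0.0] * padding
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def all_gather_params(shards: List[List[float]], original_length: int) -> List[float]:
    """Rebuild the full flat parameter list from all shards and drop the padding."""
    full = [value for shard in shards for value in shard]
    return full[:original_length]


flat = [float(i) for i in range(10)]  # 10 parameters
shards = shard_flat_params(flat, world_size=4)

print([len(s) for s in shards])  # [3, 3, 3, 3] -> each rank stores 3 values, not 10
print(shards[3])                 # [9.0, 0.0, 0.0] -> last shard carries the padding
print(all_gather_params(shards, len(flat)) == flat)  # True
```

Each rank permanently stores only its own shard, which is where the memory saving comes from; the full list exists only temporarily after the gather.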

### XLA Strategy

XLA (TPU) strategy for training on Google Cloud TPUs.

```python { .api }
class XLAStrategy(Strategy):
    """
    XLA strategy for TPU training using PyTorch XLA.

    Provides TPU support with XLA compilation for high-performance
    training on Google Cloud TPU pods.
    """

    def __init__(
        self,
        sync_module_states: bool = True,
        parallel_devices: Optional[list[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        debug: bool = False,
        **kwargs
    ):
        """
        Initialize XLA strategy.

        Args:
            sync_module_states: Whether to sync module states across TPU cores
            debug: Whether to enable XLA debug mode
        """

    def setup_module(self, module: nn.Module) -> nn.Module:
        """Set up module for TPU training."""

    def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """Reduce tensor across TPU cores using XLA collectives."""

    def all_gather(self, tensor: Tensor, *args, **kwargs) -> Tensor:
        """All-gather tensor across TPU cores."""

    def mark_step(self) -> None:
        """Mark XLA step boundary for graph compilation."""
```

### Single Device XLA Strategy

Strategy for single XLA device training (TPU, XLA on GPU).

```python { .api }
class SingleDeviceXLAStrategy(Strategy):
    """
    Strategy for training on a single XLA device.

    Optimized for a single TPU core or XLA compilation on a single GPU.
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None
    ):
        """Initialize single XLA device strategy."""
```

### Model Parallel Strategy

Strategy for tensor model parallelism across multiple devices.

```python { .api }
class ModelParallelStrategy(Strategy):
    """
    Strategy for tensor model parallelism.

    Splits individual model layers across multiple devices for very large models
    that don't fit on a single device.
    """

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None
    ):
        """Initialize model parallel strategy."""
```

### Parallel Strategy

Base class for multi-device parallel strategies.

```python { .api }
class ParallelStrategy(Strategy):
    """
    Base class for parallel training strategies.

    Provides common functionality for strategies that distribute training
    across multiple devices or processes.
    """

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None
    ):
        """Initialize parallel strategy."""
```

### XLA FSDP Strategy

Strategy combining XLA compilation with Fully Sharded Data Parallel for TPUs.

```python { .api }
class XLAFSDPStrategy(XLAStrategy):
    """
    Strategy combining XLA with Fully Sharded Data Parallel.

    Provides FSDP sharding capabilities optimized for XLA devices,
    enabling training of very large models on TPU pods.
    """

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[list[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        auto_wrap_policy: Optional[Callable] = None,
        **kwargs
    ):
        """Initialize XLA FSDP strategy."""
```

### Strategy Registry

Global registry for strategy plugins.

```python { .api }
class StrategyRegistry:
    """Registry for strategy plugins."""

    def register(
        self,
        name: str,
        strategy_class: type[Strategy],
        description: Optional[str] = None
    ) -> None:
        """Register strategy class."""

    def get(self, name: str) -> type[Strategy]:
        """Get strategy class by name."""

    def available_strategies(self) -> list[str]:
        """Get list of available strategy names."""

    def remove(self, name: str) -> None:
        """Remove strategy from registry."""

# Global registry instance
STRATEGY_REGISTRY: StrategyRegistry
```
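
A registry with this interface is essentially a name-to-class mapping. The following toy version mirrors the four methods listed above so the lookup behavior is easy to see. `ToyStrategyRegistry` and `ToyStrategy` are hypothetical stand-ins written for this illustration; the behavior on duplicate or unknown names is an assumption and may differ from the library's registry.

```python
from typing import Dict, List, Optional, Tuple, Type


class ToyStrategy:
    """Placeholder base class standing in for a real strategy."""


class ToyStrategyRegistry:
    """Minimal name -> (class, description) registry."""

    def __init__(self) -> None:
        self._entries: Dict[str, Tuple[Type[ToyStrategy], Optional[str]]] = {}

    def register(
        self,
        name: str,
        strategy_class: Type[ToyStrategy],
        description: Optional[str] = None,
    ) -> None:
        # Assumption for this sketch: registering the same name twice is an error.
        if name in self._entries:
            raise ValueError(f"'{name}' is already registered")
        self._entries[name] = (strategy_class, description)

    def get(self, name: str) -> Type[ToyStrategy]:
        if name not in self._entries:
            raise KeyError(f"unknown strategy '{name}'; available: {self.available_strategies()}")
        return self._entries[name][0]

    def available_strategies(self) -> List[str]:
        return sorted(self._entries)

    def remove(self, name: str) -> None:
        # Removing a name that is not registered is a no-op in this sketch.
        self._entries.pop(name, None)


class MyStrategy(ToyStrategy):
    pass


registry = ToyStrategyRegistry()
registry.register("my_strategy", MyStrategy, description="Example entry")
print(registry.available_strategies())            # ['my_strategy']
print(registry.get("my_strategy") is MyStrategy)  # True
registry.remove("my_strategy")
print(registry.available_strategies())            # []
```

Storing classes rather than instances lets the caller decide when, and with which arguments, a strategy is constructed.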

## Usage Examples

### Basic Strategy Selection

```python
from lightning.fabric import Fabric

# Single device training ("auto" resolves to a single-device strategy when devices=1)
fabric = Fabric(strategy="auto", devices=1)

# Data parallel (single node, multiple GPUs)
fabric = Fabric(strategy="dp", devices=4)

# Distributed data parallel (2 nodes with 4 devices each)
fabric = Fabric(strategy="ddp", devices=4, num_nodes=2)
```
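
When `devices` and `num_nodes` are combined, it helps to work out how many processes take part and how large the effective batch becomes. The sketch below assumes one process per device and the conventional rank layout in which ranks are numbered node by node. The helper names are hypothetical and written only for this arithmetic.

```python
def world_size(devices: int, num_nodes: int = 1) -> int:
    """Total number of processes, assuming one process per device."""
    return devices * num_nodes


def global_rank(node_rank: int, local_rank: int, devices: int) -> int:
    """Global rank of a process, assuming ranks are assigned node by node."""
    return node_rank * devices + local_rank


def global_batch_size(per_device_batch: int, devices: int, num_nodes: int = 1) -> int:
    """Number of samples consumed per optimizer step across all processes."""
    return per_device_batch * world_size(devices, num_nodes)


# Matches the shape of a devices=4, num_nodes=2 configuration.
print(world_size(devices=4, num_nodes=2))                   # 8
print(global_rank(node_rank=1, local_rank=2, devices=4))    # 6
print(global_batch_size(32, devices=4, num_nodes=2))        # 256
```

Because the effective batch grows with the number of processes, the learning rate or per-device batch size often needs to be revisited when scaling out.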

### DeepSpeed Configuration

```python
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# DeepSpeed ZeRO Stage 2 (the default `stage`)
fabric = Fabric(
    strategy="deepspeed",
    devices=8,
    precision="16-mixed"
)

# DeepSpeed with custom configuration (ZeRO Stage 3 with CPU offloading)
deepspeed_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"}
    },
    "train_micro_batch_size_per_gpu": 1
}

fabric = Fabric(
    strategy=DeepSpeedStrategy(config=deepspeed_config),
    devices=8
)
```

### FSDP Configuration

```python
from functools import partial

import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# FSDP with default settings
fabric = Fabric(
    strategy="fsdp",
    devices=4,
    precision="bf16-mixed"
)

# FSDP with custom configuration.
# transformer_auto_wrap_policy must be bound to the layer classes to wrap;
# nn.TransformerEncoderLayer is a placeholder for your model's block class.
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

fsdp_strategy = FSDPStrategy(
    cpu_offload=True,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16
    ),
    auto_wrap_policy=auto_wrap_policy,
    activation_checkpointing=True
)

fabric = Fabric(strategy=fsdp_strategy, devices=8)
```

### TPU Training

```python
from lightning.fabric import Fabric

# XLA/TPU training
fabric = Fabric(
    accelerator="tpu",
    strategy="xla",
    devices=8,
    precision="bf16-mixed"
)

# Mark XLA steps for optimal compilation.
# `model`, `optimizer`, `dataloader`, and `compute_loss` are assumed to be
# defined elsewhere and already set up with Fabric.
for batch in dataloader:
    optimizer.zero_grad()
    loss = compute_loss(model, batch)
    fabric.backward(loss)
    optimizer.step()

    # Mark step boundary for XLA
    if hasattr(fabric.strategy, 'mark_step'):
        fabric.strategy.mark_step()
```

### Custom Strategy

```python
from lightning.fabric import Fabric
from lightning.fabric.strategies import Strategy, STRATEGY_REGISTRY

class CustomStrategy(Strategy):
    # Override any other Strategy methods your setup requires
    def setup_module(self, module):
        # Custom module setup
        return module

    def reduce(self, tensor, *args, **kwargs):
        # Custom reduction logic
        return tensor

# Register custom strategy
STRATEGY_REGISTRY.register("custom", CustomStrategy)

# Use custom strategy
fabric = Fabric(strategy="custom")
```

### Advanced DDP Configuration

```python
from datetime import timedelta

from lightning.fabric import Fabric
from lightning.fabric.strategies import DDPStrategy

# DDP with custom settings
ddp_strategy = DDPStrategy(
    process_group_backend="nccl",
    timeout=timedelta(minutes=30),
    find_unused_parameters=False,  # Passed through **kwargs to DistributedDataParallel
    gradient_as_bucket_view=True   # Passed through **kwargs to DistributedDataParallel
)

fabric = Fabric(
    strategy=ddp_strategy,
    devices=4,
    num_nodes=2
)
```