or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

accelerators.md core-training.md distributed.md index.md precision.md strategies.md utilities.md

strategies.md (in docs/)

0

# Strategies

1

2

Distributed training strategies that define how models and data are distributed across devices and processes.

3

4

## Capabilities

5

6

### Base Strategy

7

8

Abstract base class defining the strategy interface for distributed training.

9

10

```python { .api }

11

class Strategy:

12

"""

13

Abstract base class for distributed training strategies.

14

15

Strategies define how models, optimizers, and data are distributed

16

across devices and processes for parallel training.

17

"""

18

19

def setup_environment(self) -> None:

20

"""Setup the distributed training environment."""

21

22

def setup_module(self, module: nn.Module) -> nn.Module:

23

"""Setup module for distributed training."""

24

25

def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:

26

"""Setup optimizer for distributed training."""

27

28

def module_to_device(self, module: nn.Module) -> None:

29

"""Move module to appropriate device(s)."""

30

31

def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Tensor:

32

"""Reduce tensor across processes."""

33

34

def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:

35

"""All-gather tensor across processes."""

36

37

def broadcast(self, tensor: Tensor, src: int = 0) -> Tensor:

38

"""Broadcast tensor from source process."""

39

40

def barrier(self, name: Optional[str] = None) -> None:

41

"""Synchronize all processes."""

42

43

def teardown(self) -> None:

44

"""Clean up strategy resources."""

45

```

46

47

### Single Device Strategy

48

49

Strategy for training on a single device (CPU or GPU).

50

51

```python { .api }

52

class SingleDeviceStrategy(Strategy):

53

"""

54

Strategy for single device training.

55

56

Handles training on a single CPU or GPU without distributed communication.

57

"""

58

59

def __init__(self, device: Optional[torch.device] = None):

60

"""

61

Initialize single device strategy.

62

63

Args:

64

device: Target device for training

65

"""

66

67

def setup_module(self, module: nn.Module) -> nn.Module:

68

"""Move module to target device."""

69

70

def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:

71

"""Return optimizer as-is (no distribution needed)."""

72

```

73

74

### Data Parallel Strategy

75

76

PyTorch DataParallel strategy for single-node multi-GPU training.

77

78

```python { .api }

79

class DataParallelStrategy(Strategy):

80

"""

81

DataParallel strategy for single-node multi-GPU training.

82

83

Uses PyTorch's DataParallel for simple multi-GPU training on a single node.

84

Limited scalability compared to DistributedDataParallel.

85

"""

86

87

def __init__(self, parallel_devices: Optional[list[torch.device]] = None):

88

"""

89

Initialize DataParallel strategy.

90

91

Args:

92

parallel_devices: List of devices to use for parallel training

93

"""

94

95

def setup_module(self, module: nn.Module) -> nn.Module:

96

"""Wrap module with DataParallel."""

97

98

def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:

99

"""Reduce tensor across DataParallel devices."""

100

```

101

102

### Distributed Data Parallel Strategy

103

104

PyTorch DistributedDataParallel strategy for scalable multi-GPU training.

105

106

```python { .api }

107

class DDPStrategy(Strategy):

108

"""

109

DistributedDataParallel strategy for scalable multi-GPU training.

110

111

Uses PyTorch's DDP for efficient distributed training across

112

multiple GPUs and nodes with gradient synchronization.

113

"""

114

115

def __init__(

116

self,

117

parallel_devices: Optional[list[torch.device]] = None,

118

cluster_environment: Optional[ClusterEnvironment] = None,

119

checkpoint_io: Optional[CheckpointIO] = None,

120

precision_plugin: Optional[Precision] = None,

121

ddp_comm_state: Optional[object] = None,

122

ddp_comm_hook: Optional[callable] = None,

123

ddp_comm_wrapper: Optional[callable] = None,

124

model_averaging_period: Optional[int] = None,

125

process_group_backend: Optional[str] = None,

126

timeout: Optional[timedelta] = None,

127

**kwargs

128

):

129

"""

130

Initialize DDP strategy.

131

132

Args:

133

parallel_devices: Devices for parallel training

134

cluster_environment: Cluster environment plugin

135

checkpoint_io: Checkpoint I/O plugin

136

precision_plugin: Precision plugin

137

ddp_comm_state: DDP communication state

138

ddp_comm_hook: Custom DDP communication hook

139

ddp_comm_wrapper: DDP communication wrapper

140

model_averaging_period: Period for model averaging

141

process_group_backend: Process group backend (nccl, gloo, mpi)

142

timeout: Timeout for distributed operations

143

"""

144

145

def setup_distributed(self) -> None:

146

"""Initialize distributed process group."""

147

148

def setup_module(self, module: nn.Module) -> nn.Module:

149

"""Wrap module with DistributedDataParallel."""

150

151

def configure_ddp(self) -> None:

152

"""Configure DDP-specific settings."""

153

```

154

155

### DeepSpeed Strategy

156

157

Microsoft DeepSpeed integration for large-scale model training.

158

159

```python { .api }

160

class DeepSpeedStrategy(Strategy):

161

"""

162

DeepSpeed strategy for large-scale model training.

163

164

Integrates with Microsoft DeepSpeed for memory-efficient training

165

of large models using ZeRO partitioning of optimizer states and gradients.

166

"""

167

168

def __init__(

169

self,

170

stage: int = 2,

171

remote_device: Optional[str] = None,

172

offload_optimizer: bool = False,

173

offload_parameters: bool = False,

174

offload_params_device: str = "cpu",

175

nvme_path: Optional[str] = None,

176

params_buffer_count: int = 5,

177

params_buffer_size: int = 100_000_000,

178

max_in_cpu: int = 1_000_000_000,

179

offload_optimizer_device: str = "cpu",

180

optimizer_buffer_count: int = 4,

181

block_size: int = 1048576,

182

queue_depth: int = 8,

183

single_submit: bool = False,

184

overlap_events: bool = True,

185

thread_count: int = 1,

186

config: Optional[Union[str, dict]] = None,

187

logging_level: int = logging.WARN,

188

parallel_devices: Optional[list[torch.device]] = None,

189

cluster_environment: Optional[ClusterEnvironment] = None,

190

checkpoint_io: Optional[CheckpointIO] = None,

191

precision_plugin: Optional[Precision] = None,

192

process_group_backend: Optional[str] = None,

193

timeout: Optional[timedelta] = None,

194

**kwargs

195

):

196

"""

197

Initialize DeepSpeed strategy.

198

199

Args:

200

stage: DeepSpeed ZeRO stage (1, 2, or 3)

201

remote_device: Remote device for offloading

202

offload_optimizer: Whether to offload optimizer states

203

offload_parameters: Whether to offload parameters

204

offload_params_device: Device for parameter offloading

205

nvme_path: Path to NVMe storage for offloading

206

config: DeepSpeed configuration dict or path to config file

207

Other args: Additional DeepSpeed configuration options

208

"""

209

210

def setup_module_and_optimizers(

211

self,

212

module: nn.Module,

213

optimizers: list[Optimizer]

214

) -> tuple[nn.Module, list[Optimizer]]:

215

"""Setup module and optimizers with DeepSpeed engine."""

216

217

def configure_deepspeed_config(self, config: dict) -> dict:

218

"""Configure DeepSpeed configuration dictionary."""

219

```

220

221

### FSDP Strategy

222

223

Fully Sharded Data Parallel strategy for memory-efficient large model training.

224

225

```python { .api }

226

class FSDPStrategy(Strategy):

227

"""

228

Fully Sharded Data Parallel strategy for large model training.

229

230

Uses PyTorch's FSDP to shard model parameters, gradients, and

231

optimizer states across devices for memory-efficient training.

232

"""

233

234

def __init__(

235

self,

236

cpu_offload: Optional[bool] = None,

237

mixed_precision: Optional[MixedPrecision] = None,

238

auto_wrap_policy: Optional[callable] = None,

239

activation_checkpointing: Optional[bool] = None,

240

activation_checkpointing_policy: Optional[callable] = None,

241

sharding_strategy: Optional[ShardingStrategy] = None,

242

state_dict_type: Optional[StateDictType] = None,

243

use_orig_params: bool = False,

244

limit_all_gathers: bool = True,

245

sync_module_states: bool = False,

246

forward_prefetch: bool = False,

247

parallel_devices: Optional[list[torch.device]] = None,

248

cluster_environment: Optional[ClusterEnvironment] = None,

249

checkpoint_io: Optional[CheckpointIO] = None,

250

precision_plugin: Optional[Precision] = None,

251

process_group_backend: Optional[str] = None,

252

timeout: Optional[timedelta] = None,

253

**kwargs

254

):

255

"""

256

Initialize FSDP strategy.

257

258

Args:

259

cpu_offload: Whether to offload parameters and gradients to CPU

260

mixed_precision: Mixed precision configuration

261

auto_wrap_policy: Policy for automatic module wrapping

262

activation_checkpointing: Whether to use activation checkpointing

263

sharding_strategy: Parameter sharding strategy

264

state_dict_type: Type of state dict for checkpointing

265

use_orig_params: Whether FSDP exposes the module's original parameters instead of its flattened parameters

266

"""

267

268

def setup_module(self, module: nn.Module) -> nn.Module:

269

"""Wrap module with FSDP."""

270

271

def configure_fsdp_auto_wrap_policy(self, module: nn.Module) -> Optional[callable]:

272

"""Configure automatic wrapping policy for FSDP."""

273

```

274

275

### XLA Strategy

276

277

XLA (TPU) strategy for training on Google Cloud TPUs.

278

279

```python { .api }

280

class XLAStrategy(Strategy):

281

"""

282

XLA strategy for TPU training using PyTorch XLA.

283

284

Provides TPU support with XLA compilation for high-performance

285

training on Google Cloud TPU pods.

286

"""

287

288

def __init__(

289

self,

290

sync_module_states: bool = True,

291

parallel_devices: Optional[list[torch.device]] = None,

292

cluster_environment: Optional[ClusterEnvironment] = None,

293

checkpoint_io: Optional[CheckpointIO] = None,

294

precision_plugin: Optional[Precision] = None,

295

debug: bool = False,

296

**kwargs

297

):

298

"""

299

Initialize XLA strategy.

300

301

Args:

302

sync_module_states: Whether to sync module states across TPU cores

303

debug: Whether to enable XLA debug mode

304

"""

305

306

def setup_module(self, module: nn.Module) -> nn.Module:

307

"""Setup module for TPU training."""

308

309

def reduce(self, tensor: Tensor, *args, **kwargs) -> Tensor:

310

"""Reduce tensor across TPU cores using XLA collectives."""

311

312

def all_gather(self, tensor: Tensor, *args, **kwargs) -> Tensor:

313

"""All-gather tensor across TPU cores."""

314

315

def mark_step(self) -> None:

316

"""Mark XLA step boundary for graph compilation."""

317

```

318

319

### Single Device XLA Strategy

320

321

Strategy for single XLA device training (TPU, XLA on GPU).

322

323

```python { .api }

324

class SingleDeviceXLAStrategy(Strategy):

325

"""

326

Strategy for training on a single XLA device.

327

328

Optimized for a single TPU core or XLA compilation on a single GPU.

329

"""

330

331

def __init__(

332

self,

333

device: Optional[torch.device] = None,

334

accelerator: Optional[Accelerator] = None,

335

checkpoint_io: Optional[CheckpointIO] = None,

336

precision_plugin: Optional[Precision] = None

337

):

338

"""Initialize single XLA device strategy."""

339

```

340

341

### Model Parallel Strategy

342

343

Strategy for tensor model parallelism across multiple devices.

344

345

```python { .api }

346

class ModelParallelStrategy(Strategy):

347

"""

348

Strategy for tensor model parallelism.

349

350

Splits individual model layers across multiple devices for very large models

351

that don't fit on a single device.

352

"""

353

354

def __init__(

355

self,

356

accelerator: Optional[Accelerator] = None,

357

checkpoint_io: Optional[CheckpointIO] = None,

358

precision_plugin: Optional[Precision] = None

359

):

360

"""Initialize model parallel strategy."""

361

```

362

363

### Parallel Strategy

364

365

Base class for multi-device parallel strategies.

366

367

```python { .api }

368

class ParallelStrategy(Strategy):

369

"""

370

Base class for parallel training strategies.

371

372

Provides common functionality for strategies that distribute training

373

across multiple devices or processes.

374

"""

375

376

def __init__(

377

self,

378

accelerator: Optional[Accelerator] = None,

379

parallel_devices: Optional[list[torch.device]] = None,

380

checkpoint_io: Optional[CheckpointIO] = None,

381

precision_plugin: Optional[Precision] = None

382

):

383

"""Initialize parallel strategy."""

384

```

385

386

### XLA FSDP Strategy

387

388

Strategy combining XLA compilation with Fully Sharded Data Parallel for TPUs.

389

390

```python { .api }

391

class XLAFSDPStrategy(XLAStrategy):

392

"""

393

Strategy combining XLA with Fully Sharded Data Parallel.

394

395

Provides FSDP sharding capabilities optimized for XLA devices,

396

enabling training of very large models on TPU pods.

397

"""

398

399

def __init__(

400

self,

401

accelerator: Optional[Accelerator] = None,

402

parallel_devices: Optional[list[torch.device]] = None,

403

checkpoint_io: Optional[CheckpointIO] = None,

404

precision_plugin: Optional[Precision] = None,

405

auto_wrap_policy: Optional[Callable] = None,

406

**kwargs

407

):

408

"""Initialize XLA FSDP strategy."""

409

```

410

411

### Strategy Registry

412

413

Global registry for strategy plugins.

414

415

```python { .api }

416

class StrategyRegistry:

417

"""Registry for strategy plugins."""

418

419

def register(

420

self,

421

name: str,

422

strategy_class: type[Strategy],

423

description: Optional[str] = None

424

) -> None:

425

"""Register strategy class."""

426

427

def get(self, name: str) -> type[Strategy]:

428

"""Get strategy class by name."""

429

430

def available_strategies(self) -> list[str]:

431

"""Get list of available strategy names."""

432

433

def remove(self, name: str) -> None:

434

"""Remove strategy from registry."""

435

436

# Global registry instance

437

STRATEGY_REGISTRY: StrategyRegistry

438

```

439

440

## Usage Examples

441

442

### Basic Strategy Selection

443

444

```python

445

from lightning.fabric import Fabric

446

447

# Single device training

448

fabric = Fabric(strategy="auto") # Auto-selects single device

449

450

# Data parallel (single node, multiple GPUs)

451

fabric = Fabric(strategy="dp", devices=4)

452

453

# Distributed data parallel

454

fabric = Fabric(strategy="ddp", devices=4, num_nodes=2)

455

```

456

457

### DeepSpeed Configuration

458

459

```python

460

# DeepSpeed ZeRO Stage 2

461

fabric = Fabric(

462

strategy="deepspeed",

463

devices=8,

464

precision="16-mixed"

465

)

466

467

# DeepSpeed with custom configuration

468

deepspeed_config = {

469

"zero_optimization": {

470

"stage": 3,

471

"offload_optimizer": {"device": "cpu"},

472

"offload_param": {"device": "cpu"}

473

},

474

"train_micro_batch_size_per_gpu": 1

475

}

476

477

fabric = Fabric(

478

strategy=DeepSpeedStrategy(config=deepspeed_config),

479

devices=8

480

)

481

```

482

483

### FSDP Configuration

484

485

```python

486

# FSDP with default settings and bf16 mixed precision

487

fabric = Fabric(

488

strategy="fsdp",

489

devices=4,

490

precision="bf16-mixed"

491

)

492

493

# FSDP with custom configuration

494

from torch.distributed.fsdp import MixedPrecision

495

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

496

497

fsdp_strategy = FSDPStrategy(

498

cpu_offload=True,

499

mixed_precision=MixedPrecision(

500

param_dtype=torch.bfloat16,

501

reduce_dtype=torch.bfloat16,

502

buffer_dtype=torch.bfloat16

503

),

504

auto_wrap_policy=transformer_auto_wrap_policy,

505

activation_checkpointing=True

506

)

507

508

fabric = Fabric(strategy=fsdp_strategy, devices=8)

509

```

510

511

### TPU Training

512

513

```python

514

# XLA/TPU training

515

fabric = Fabric(

516

accelerator="tpu",

517

strategy="xla",

518

devices=8,

519

precision="bf16-mixed"

520

)

521

522

# Mark XLA steps for optimal compilation

523

for batch in dataloader:

524

loss = compute_loss(model, batch)

525

fabric.backward(loss)

526

optimizer.step()

527

528

# Mark step boundary for XLA

529

if hasattr(fabric.strategy, 'mark_step'):

530

fabric.strategy.mark_step()

531

```

532

533

### Custom Strategy

534

535

```python

536

from lightning.fabric.strategies import Strategy, STRATEGY_REGISTRY

537

538

class CustomStrategy(Strategy):

539

def setup_module(self, module):

540

# Custom module setup

541

return module

542

543

def reduce(self, tensor, *args, **kwargs):

544

# Custom reduction logic

545

return tensor

546

547

# Register custom strategy

548

STRATEGY_REGISTRY.register("custom", CustomStrategy)

549

550

# Use custom strategy

551

fabric = Fabric(strategy="custom")

552

```

553

554

### Advanced DDP Configuration

555

556

```python

557

from datetime import timedelta

558

559

# DDP with custom settings

560

ddp_strategy = DDPStrategy(

561

process_group_backend="nccl",

562

timeout=timedelta(minutes=30),

563

find_unused_parameters=False, # Set via kwargs

564

gradient_as_bucket_view=True # Set via kwargs

565

)

566

567

fabric = Fabric(

568

strategy=ddp_strategy,

569

devices=4,

570

num_nodes=2

571

)

572

```