or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

accelerators.md · callbacks.md · core-training.md · data.md · fabric.md · index.md · loggers.md · precision.md · profilers.md · strategies.md

docs/strategies.md

# Distributed Training Strategies

Multiple strategies for distributed and parallel training including data parallel, distributed data parallel, fully sharded data parallel, model parallel, and specialized strategies for different hardware configurations.

## Capabilities

### Distributed Data Parallel (DDP)

Multi-GPU and multi-node distributed training strategy that replicates the model across devices and synchronizes gradients.

```python { .api }
class DDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
        model_averaging_period: Optional[int] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        start_method: str = "popen",
        **kwargs
    ):
        """
        Initialize DDP strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin for mixed precision
            ddp_comm_state: DDP communication state
            ddp_comm_hook: Custom communication hook
            ddp_comm_wrapper: Communication wrapper
            model_averaging_period: Period for model averaging
            process_group_backend: Backend for process group ('nccl', 'gloo')
            timeout: Timeout for distributed operations
            start_method: Method to start processes
        """
```

### Fully Sharded Data Parallel (FSDP)

Memory-efficient distributed training that shards model parameters, gradients, and optimizer states across devices.

```python { .api }
class FSDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = None,
        auto_wrap_policy: Optional[Callable] = None,
        cpu_offload: Union[bool, CPUOffload] = False,
        mixed_precision: Optional[MixedPrecision] = None,
        sharding_strategy: Union[ShardingStrategy, str] = "FULL_SHARD",
        backward_prefetch: Optional[BackwardPrefetch] = None,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        use_orig_params: bool = True,
        param_init_fn: Optional[Callable] = None,
        sync_module_states: bool = False,
        **kwargs
    ):
        """
        Initialize FSDP strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            process_group_backend: Backend for process group
            timeout: Timeout for distributed operations
            auto_wrap_policy: Policy for automatic module wrapping
            cpu_offload: Enable CPU offloading of parameters
            mixed_precision: Mixed precision configuration
            sharding_strategy: Strategy for parameter sharding
            backward_prefetch: Prefetch strategy for backward pass
            forward_prefetch: Enable forward prefetching
            limit_all_gathers: Limit all-gather operations
            use_orig_params: Use original parameters
            param_init_fn: Parameter initialization function
            sync_module_states: Synchronize module states
        """
```

### DeepSpeed Integration

Integration with Microsoft DeepSpeed for memory-efficient training of large models with advanced optimization techniques.

```python { .api }
class DeepSpeedStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        zero_optimization: bool = True,
        stage: int = 2,
        remote_device: Optional[str] = None,
        offload_optimizer: bool = False,
        offload_parameters: bool = False,
        offload_params_device: str = "cpu",
        nvme_path: str = "/local_nvme",
        params_buffer_count: int = 5,
        params_buffer_size: int = 100_000_000,
        max_in_cpu: int = 1_000_000_000,
        offload_optimizer_device: str = "cpu",
        optimizer_buffer_count: int = 4,
        block_size: int = 1048576,
        queue_depth: int = 8,
        single_submit: bool = False,
        overlap_events: bool = True,
        thread_count: int = 1,
        pin_memory: bool = False,
        sub_group_size: int = 1_000_000_000_000,
        cpu_checkpointing: bool = False,
        contiguous_gradients: bool = True,
        overlap_comm: bool = True,
        allgather_partitions: bool = True,
        reduce_scatter: bool = True,
        allgather_bucket_size: int = 200_000_000,
        reduce_bucket_size: int = 200_000_000,
        zero_allow_untested_optimizer: bool = True,
        logging_batch_size_per_gpu: str = "auto",
        config: Optional[Union[Path, str, Dict]] = None,
        logging_level: int = logging.WARN,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        process_group_backend: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize DeepSpeed strategy.

        Args:
            accelerator: Hardware accelerator to use
            zero_optimization: Enable ZeRO optimization
            stage: ZeRO stage (1, 2, or 3)
            remote_device: Remote device for parameter storage
            offload_optimizer: Offload optimizer to CPU
            offload_parameters: Offload parameters to CPU
            offload_params_device: Device for parameter offloading
            nvme_path: Path for NVMe offloading
            params_buffer_count: Number of parameter buffers
            params_buffer_size: Size of parameter buffers
            max_in_cpu: Maximum parameters in CPU memory
            offload_optimizer_device: Device for optimizer offloading
            config: DeepSpeed configuration file or dictionary
            logging_level: Logging level for DeepSpeed
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            process_group_backend: Backend for process group
        """
```

### Data Parallel Strategy

Simple data parallelism that replicates the model on multiple devices and averages gradients.

```python { .api }
class DataParallelStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize DataParallel strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for parallel training
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```

### Single Device Strategy

Strategy for training on a single device (CPU or GPU).

```python { .api }
class SingleDeviceStrategy:
    def __init__(
        self,
        device: torch.device,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize single device strategy.

        Args:
            device: Device to use for training
            accelerator: Hardware accelerator to use
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```

### XLA Strategies

Strategies for Google TPU training using XLA compilation.

```python { .api }
class XLAStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        sync_module_states: bool = True
    ):
        """
        Initialize XLA strategy for multi-TPU training.

        Args:
            accelerator: XLA accelerator
            parallel_devices: List of TPU devices
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            debug: Enable debug mode
            sync_module_states: Synchronize module states
        """

class SingleDeviceXLAStrategy:
    def __init__(
        self,
        device: torch.device,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False
    ):
        """
        Initialize single TPU device strategy.

        Args:
            device: TPU device to use
            accelerator: XLA accelerator
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
            debug: Enable debug mode
        """

class XLAFSDPStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
        **kwargs
    ):
        """
        Initialize XLA FSDP strategy combining XLA with fully sharded data parallel.

        Args:
            accelerator: XLA accelerator
            parallel_devices: List of TPU devices
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```

### Model Parallel Strategy

Strategy for model parallelism where different parts of the model are placed on different devices.

```python { .api }
class ModelParallelStrategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Initialize model parallel strategy.

        Args:
            accelerator: Hardware accelerator to use
            parallel_devices: List of devices for model placement
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """
```

### Base Strategy Classes

Base classes for creating custom training strategies.

```python { .api }
class Strategy:
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Base strategy class.

        Args:
            accelerator: Hardware accelerator
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

    def setup_environment(self) -> None:
        """Set up the training environment."""

    def setup(self, trainer: Trainer) -> None:
        """Set up the strategy with trainer."""

    def teardown(self) -> None:
        """Clean up the strategy."""

class ParallelStrategy(Strategy):
    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None
    ):
        """
        Base parallel strategy class.

        Args:
            accelerator: Hardware accelerator
            parallel_devices: List of devices for parallel training
            cluster_environment: Cluster configuration
            checkpoint_io: Checkpoint I/O plugin
            precision_plugin: Precision plugin
        """

    @property
    def global_rank(self) -> int:
        """Global rank of the current process."""

    @property
    def local_rank(self) -> int:
        """Local rank of the current process."""

    @property
    def world_size(self) -> int:
        """Total number of processes."""

    def all_gather(self, tensor: torch.Tensor, sync_grads: bool = False) -> torch.Tensor:
        """Gather tensor from all processes."""

    def all_reduce(self, tensor: torch.Tensor, reduce_op: str = "mean") -> torch.Tensor:
        """Reduce tensor across all processes."""

    def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
        """Broadcast tensor from source to all processes."""

    def barrier(self, name: Optional[str] = None) -> None:
        """Synchronize all processes."""
```

## Usage Examples

### Basic Strategy Usage

```python
from lightning import Trainer

# Use DDP strategy
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp"
)

# Use FSDP strategy
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy="fsdp"
)
```

### Advanced Strategy Configuration

```python
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy
from datetime import timedelta

# Configure DDP with custom settings
ddp_strategy = DDPStrategy(
    process_group_backend="nccl",
    timeout=timedelta(seconds=1800),
    start_method="spawn"
)

trainer = Trainer(
    accelerator="gpu",
    devices=4,
    strategy=ddp_strategy,
    precision="16-mixed"
)

# Configure FSDP with CPU offloading
from torch.distributed.fsdp import CPUOffload, ShardingStrategy

fsdp_strategy = FSDPStrategy(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=None,  # Let Lightning handle precision
    auto_wrap_policy=None  # Use default wrapping
)

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=fsdp_strategy,
    precision="bf16-mixed"
)
```

### DeepSpeed Configuration

```python
from lightning import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

# DeepSpeed ZeRO Stage 3 with offloading
deepspeed_strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,
    offload_parameters=True,
    remote_device="nvme",
    nvme_path="/local_nvme"
)

trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=deepspeed_strategy,
    precision="16-mixed"
)

# DeepSpeed with custom config file
trainer = Trainer(
    accelerator="gpu",
    devices=8,
    strategy=DeepSpeedStrategy(config="deepspeed_config.json"),
    precision="16-mixed"
)
```