
# Core Training

The main Accelerator class and essential training functionality that forms the foundation of Accelerate's distributed training capabilities. This includes mixed precision support, gradient accumulation, device management, and basic distributed operations.

## Capabilities

### Accelerator Class

The central orchestrator for distributed training that handles hardware detection, mixed precision setup, and training component preparation.

```python { .api }
class Accelerator:
    """
    Main class for coordinating distributed training and mixed precision.

    Handles device placement, distributed backend setup, mixed precision
    configuration, and provides training utilities.
    """

    def __init__(
        self,
        device_placement: bool = True,
        split_batches: bool = False,
        mixed_precision: str | None = None,
        gradient_accumulation_steps: int = 1,
        cpu: bool = False,
        dataloader_config: DataLoaderConfiguration | None = None,
        deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
        fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
        megatron_lm_plugin: MegatronLMPlugin | None = None,
        rng_types: list[str] | None = None,
        log_with: str | list[str] | None = None,
        project_dir: str | None = None,
        project_config: ProjectConfiguration | None = None,
        gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
        step_scheduler_with_optimizer: bool = True,
        kwargs_handlers: list[KwargsHandler] | None = None,
        dynamo_backend: str | None = None,
        dynamo_plugin: TorchDynamoPlugin | None = None,
        parallelism_config: ParallelismConfig | None = None
    ):
        """
        Initialize Accelerator with training configuration.

        Parameters:
        - device_placement: Whether to automatically place tensors on correct device
        - split_batches: Whether to split batches across processes
        - mixed_precision: Mixed precision mode ("no", "fp16", "bf16", "fp8")
        - gradient_accumulation_steps: Number of steps to accumulate gradients
        - cpu: Force CPU usage even if GPU available
        - dataloader_config: DataLoader behavior configuration
        - deepspeed_plugin: DeepSpeed configuration plugin (single or per-model dict)
        - fsdp_plugin: FSDP configuration plugin
        - megatron_lm_plugin: Megatron-LM configuration plugin
        - rng_types: Random number generator types to synchronize
        - log_with: Experiment tracking backends to use
        - project_dir: Directory for project outputs
        - project_config: Project and logging configuration
        - gradient_accumulation_plugin: Gradient accumulation configuration
        - step_scheduler_with_optimizer: Whether to step scheduler with optimizer
        - kwargs_handlers: Additional configuration handlers
        - dynamo_backend: Backend for torch.compile optimization
        - dynamo_plugin: Torch Dynamo configuration plugin
        - parallelism_config: Parallelism configuration object
        """
```
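
For orientation, a minimal sketch of constructing an `Accelerator` and inspecting the resulting environment; the argument values below are illustrative, and every argument is optional:

```python
from accelerate import Accelerator

# All arguments are optional; the defaults give single-process training
# with automatic device placement.
accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=2)

print(accelerator.device)           # e.g. device(type='cuda', index=0)
print(accelerator.num_processes)    # 1 unless launched via `accelerate launch`
print(accelerator.mixed_precision)  # "bf16"
```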

### Training Preparation

Methods for preparing models, optimizers, and data loaders for distributed training.

```python { .api }
def prepare(self, *args):
    """
    Prepare models, optimizers, dataloaders for distributed training.

    Automatically wraps objects for the current distributed setup and
    applies mixed precision, device placement, and other configurations.

    Parameters:
    - *args: Models, optimizers, dataloaders, schedulers to prepare

    Returns:
    Tuple of prepared objects in same order as input
    """

def prepare_model(self, model: torch.nn.Module, device_placement: bool | None = None):
    """
    Prepare a single model for distributed training.

    Parameters:
    - model: PyTorch model to prepare
    - device_placement: Override default device placement behavior

    Returns:
    Prepared model wrapped for distributed training
    """

def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
    """
    Prepare optimizer for distributed training.

    Parameters:
    - optimizer: PyTorch optimizer to prepare

    Returns:
    Wrapped optimizer for distributed training
    """

def prepare_data_loader(
    self,
    data_loader: torch.utils.data.DataLoader,
    device_placement: bool | None = None
):
    """
    Prepare DataLoader for distributed training.

    Parameters:
    - data_loader: PyTorch DataLoader to prepare
    - device_placement: Override default device placement

    Returns:
    DataLoader configured for distributed training
    """

def prepare_scheduler(self, scheduler):
    """
    Prepare learning rate scheduler for distributed training.

    Parameters:
    - scheduler: PyTorch scheduler to prepare

    Returns:
    Wrapped scheduler for distributed training
    """
```
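
As a usage sketch (the component names below are illustrative, not part of the API): passing everything through a single `prepare()` call keeps the return order identical to the argument order, while the `prepare_*` methods do the same work one object at a time.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Plain PyTorch objects; nothing Accelerate-specific yet.
model = nn.Linear(784, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Wrapped in one call; returned in the same order they were passed.
model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)
```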

### Training Operations

Core training operations including backward pass, gradient clipping, and model unwrapping.

```python { .api }
def backward(self, loss: torch.Tensor, **kwargs):
    """
    Perform backward pass with automatic mixed precision scaling.

    Parameters:
    - loss: Loss tensor to compute gradients from
    - **kwargs: Additional arguments passed to loss.backward()
    """

def clip_grad_norm_(
    self,
    parameters,
    max_norm: float,
    norm_type: float = 2.0
):
    """
    Clip gradient norm across all processes.

    Parameters:
    - parameters: Model parameters or parameter groups
    - max_norm: Maximum norm of gradients
    - norm_type: Type of norm to compute (default: 2.0)

    Returns:
    Total norm of parameters (viewed as single vector)
    """

def clip_grad_value_(self, parameters, clip_value: float):
    """
    Clip gradient values to specified range.

    Parameters:
    - parameters: Model parameters to clip
    - clip_value: Maximum absolute value for gradients
    """

def unwrap_model(self, model: torch.nn.Module, keep_fp32_wrapper: bool = True):
    """
    Extract original model from distributed training wrappers.

    Parameters:
    - model: Wrapped model from prepare()
    - keep_fp32_wrapper: Whether to keep mixed precision wrapper

    Returns:
    Original unwrapped model
    """
```
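
Continuing the preparation sketch above, a hedged example of one training epoch that strings these operations together; `clip_grad_norm_` stands in for `torch.nn.utils.clip_grad_norm_` so the norm is handled correctly across processes (and after unscaling under fp16):

```python
import torch

model.train()
for inputs, labels in dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)

    # Scales the loss automatically when fp16 is enabled.
    accelerator.backward(loss)

    # Clip after backward, before the optimizer step.
    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

# Strip DDP/FSDP wrappers before saving a plain state dict.
torch.save(accelerator.unwrap_model(model).state_dict(), "model.pt")
```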

### Distributed Communication

Basic distributed operations for gathering, reducing, and broadcasting tensors.

```python { .api }
def gather(self, tensor: torch.Tensor):
    """
    Gather tensor from all processes.

    Parameters:
    - tensor: Tensor to gather across processes

    Returns:
    Concatenated tensor from all processes (available on every process)
    """

def gather_for_metrics(self, input_data):
    """
    Gather data from all processes for metrics computation.

    Automatically handles padding for uneven batch sizes.

    Parameters:
    - input_data: Data to gather (tensors, lists, dicts)

    Returns:
    Gathered data from all processes
    """

def reduce(self, tensor: torch.Tensor, reduction: str = "mean"):
    """
    Reduce tensor across all processes.

    Parameters:
    - tensor: Tensor to reduce
    - reduction: Reduction operation ("mean", "sum")

    Returns:
    Reduced tensor
    """

def pad_across_processes(self, tensor: torch.Tensor, dim: int = 0, pad_index: int = 0):
    """
    Pad tensor to same size across all processes.

    Parameters:
    - tensor: Tensor to pad
    - dim: Dimension to pad along
    - pad_index: Value to use for padding

    Returns:
    Padded tensor
    """
```
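
A sketch of a distributed evaluation loop using these operations, again continuing the prepared objects from above and assuming an `eval_dataloader` prepared the same way:

```python
import torch

model.eval()
all_preds, all_labels = [], []
for inputs, labels in eval_dataloader:
    with torch.no_grad():
        preds = model(inputs).argmax(dim=-1)

    # Drops the duplicate samples added to even out the final batch,
    # so metric counts match the underlying dataset exactly.
    preds, labels = accelerator.gather_for_metrics((preds, labels))
    all_preds.append(preds)
    all_labels.append(labels)

accuracy = (torch.cat(all_preds) == torch.cat(all_labels)).float().mean()
accelerator.print(f"accuracy: {accuracy.item():.4f}")
```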

### Context Managers

Context managers for controlling training behavior and process synchronization.

```python { .api }
def accumulate(self, *models):
    """
    Context manager for gradient accumulation.

    Automatically handles gradient synchronization based on
    gradient_accumulation_steps configuration.

    Parameters:
    - *models: Models to control gradient synchronization for
    """

def no_sync(self, *models):
    """
    Context manager to disable gradient synchronization.

    Parameters:
    - *models: Models to disable synchronization for
    """

def main_process_first(self):
    """
    Context manager to run code on main process first.

    Ensures main process completes before other processes continue.
    Useful for dataset preprocessing, model downloading, etc.
    """

def local_main_process_first(self):
    """
    Context manager to run code on local main process first.

    Similar to main_process_first but per-node instead of global.
    """

def autocast(self, cache_enabled: bool | None = None):
    """
    Context manager for mixed precision autocast.

    Parameters:
    - cache_enabled: Whether to enable autocast cache

    Returns:
    Autocast context manager configured for current precision
    """
```
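
For instance, a hedged sketch of `main_process_first` guarding a cache-backed preprocessing step, and of `no_sync` skipping the gradient all-reduce; the helper functions here are hypothetical stand-ins:

```python
# Main process runs the body first; other ranks wait, then run it and
# typically hit the cache the main process just populated.
with accelerator.main_process_first():
    dataset = build_and_cache_dataset("my-corpus")  # hypothetical helper

# Skips gradient all-reduce inside the block; accumulate() (see Usage
# Examples below) applies this for you on non-final accumulation steps.
with accelerator.no_sync(model):
    loss = compute_loss(model, batch)  # hypothetical helper
    accelerator.backward(loss)
```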

### Process Control and Utilities

Methods for process management, synchronization, and training control.

```python { .api }
def wait_for_everyone(self):
    """
    Synchronization barrier - wait for all processes to reach this point.
    """

def print(self, *args, **kwargs):
    """
    Print only on the main process.

    Parameters:
    - *args: Arguments to print
    - **kwargs: Keyword arguments for print function
    """

def split_between_processes(self, inputs, apply_padding: bool = False):
    """
    Split inputs between processes for distributed processing.

    Parameters:
    - inputs: Data to split between processes
    - apply_padding: Whether to pad to equal sizes

    Returns:
    Portion of inputs for current process
    """

def free_memory(self):
    """
    Free memory by clearing internal caches and calling garbage collection.
    """

def clear(self):
    """
    Reset Accelerator to initial state and free memory.
    """

def skip_first_batches(self, dataloader, num_batches: int):
    """
    Skip the first num_batches in a DataLoader.

    Parameters:
    - dataloader: DataLoader to skip batches from
    - num_batches: Number of batches to skip

    Returns:
    DataLoader starting from the specified batch
    """

def verify_device_map(self, model: torch.nn.Module):
    """
    Verify that the device map is valid for the given model.

    Parameters:
    - model: Model to verify device map for
    """

def lomo_backward(self, loss: torch.Tensor, learning_rate: float):
    """
    Perform LOMO (Low-Memory Optimization) backward pass.

    Parameters:
    - loss: Loss tensor to compute gradients from
    - learning_rate: Learning rate for LOMO optimizer
    """

def set_trigger(self):
    """
    Set a trigger on the current process to signal an event
    (such as early stopping) to all other processes.
    """

def check_trigger(self):
    """
    Check whether the trigger has been set on any process,
    resetting it if so.

    Returns:
    bool: Whether any process set the trigger
    """
```
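
A short sketch of two common patterns, resuming mid-epoch after loading a checkpoint and sharding ad-hoc work across ranks (the batch count and inputs are illustrative):

```python
# Resume mid-epoch: skip the batches already consumed before the checkpoint.
resumed_loader = accelerator.skip_first_batches(dataloader, num_batches=120)

# Shard a list across processes; each rank receives its own slice.
items = ["alpha", "beta", "gamma", "delta", "epsilon"]
with accelerator.split_between_processes(items) as my_items:
    results = [item.upper() for item in my_items]

accelerator.wait_for_everyone()      # barrier before any shared-state step
accelerator.print("all ranks done")  # printed once, on the main process
```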

### Properties

Key properties providing information about the training environment.

```python { .api }
@property
def device(self) -> torch.device:
    """Current device for this process."""

@property
def state(self) -> PartialState:
    """Access to the underlying PartialState."""

@property
def is_main_process(self) -> bool:
    """Whether this is the main process (rank 0)."""

@property
def is_local_main_process(self) -> bool:
    """Whether this is the local main process on this node."""

@property
def process_index(self) -> int:
    """Global process index/rank."""

@property
def local_process_index(self) -> int:
    """Local process index on this node."""

@property
def num_processes(self) -> int:
    """Total number of processes."""

@property
def distributed_type(self) -> DistributedType:
    """Type of distributed training backend being used."""

@property
def mixed_precision(self) -> str:
    """Mixed precision mode being used."""

@property
def use_distributed(self) -> bool:
    """Whether distributed training is being used."""

@property
def should_save_model(self) -> bool:
    """Whether this process should save the model."""

@property
def tensor_parallel_rank(self) -> int:
    """Tensor parallelism rank for this process."""

@property
def pipeline_parallel_rank(self) -> int:
    """Pipeline parallelism rank for this process."""

@property
def context_parallel_rank(self) -> int:
    """Context parallelism rank for this process."""

@property
def data_parallel_rank(self) -> int:
    """Data parallelism rank for this process."""

@property
def fp8_backend(self) -> str | None:
    """FP8 backend being used."""

@property
def is_fsdp2(self) -> bool:
    """Whether FSDP2 is being used."""
```
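
These properties typically drive rank-conditional logic, for example:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Report topology once per run.
if accelerator.is_main_process:
    print(f"{accelerator.num_processes} process(es), "
          f"backend: {accelerator.distributed_type}")

# One-time-per-node work, e.g. populating node-local scratch space.
if accelerator.is_local_main_process:
    ...

# Distinct per-rank seeds, e.g. so data augmentation differs across ranks.
torch.manual_seed(42 + accelerator.process_index)
```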

## Usage Examples

### Basic Training Setup

```python
from accelerate import Accelerator
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Initialize with mixed precision
accelerator = Accelerator(
    mixed_precision="fp16",
    gradient_accumulation_steps=4
)

# Create model, optimizer, data, and loss
model = nn.Linear(784, 10)
optimizer = torch.optim.Adam(model.parameters())
dataset = TensorDataset(torch.randn(512, 784), torch.randint(0, 10, (512,)))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
criterion = nn.CrossEntropyLoss()

# Prepare for distributed training
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Training loop with gradient accumulation
for inputs, labels in dataloader:
    with accelerator.accumulate(model):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

### Advanced Configuration

```python
from accelerate import Accelerator, DataLoaderConfiguration, ProjectConfiguration

# Advanced configuration
dataloader_config = DataLoaderConfiguration(
    split_batches=True,
    dispatch_batches=False
)

project_config = ProjectConfiguration(
    project_dir="./experiments",
    automatic_checkpoint_naming=True,
    total_limit=5
)

accelerator = Accelerator(
    device_placement=True,
    mixed_precision="bf16",
    gradient_accumulation_steps=8,
    dataloader_config=dataloader_config,
    project_config=project_config
)
```