# Core Training

The main Accelerator class and the essential training functionality that form the foundation of Accelerate's distributed training capabilities. This includes mixed precision support, gradient accumulation, device management, and basic distributed operations.

## Capabilities

### Accelerator Class

The central orchestrator for distributed training that handles hardware detection, mixed precision setup, and training component preparation.

```python { .api }
class Accelerator:
    """
    Main class for coordinating distributed training and mixed precision.

    Handles device placement, distributed backend setup, mixed precision
    configuration, and provides training utilities.
    """

    def __init__(
        self,
        device_placement: bool = True,
        split_batches: bool = False,
        mixed_precision: str | None = None,
        gradient_accumulation_steps: int = 1,
        cpu: bool = False,
        dataloader_config: DataLoaderConfiguration | None = None,
        deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
        fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
        megatron_lm_plugin: MegatronLMPlugin | None = None,
        rng_types: list[str] | None = None,
        log_with: str | list[str] | None = None,
        project_dir: str | None = None,
        project_config: ProjectConfiguration | None = None,
        gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
        step_scheduler_with_optimizer: bool = True,
        kwargs_handlers: list[KwargsHandler] | None = None,
        dynamo_backend: str | None = None,
        dynamo_plugin: TorchDynamoPlugin | None = None,
        parallelism_config: ParallelismConfig | None = None
    ):
        """
        Initialize Accelerator with training configuration.

        Parameters:
        - device_placement: Whether to automatically place tensors on the correct device
        - split_batches: Whether to split batches across processes
        - mixed_precision: Mixed precision mode ("no", "fp16", "bf16", "fp8")
        - gradient_accumulation_steps: Number of steps to accumulate gradients over
        - cpu: Force CPU usage even if a GPU is available
        - dataloader_config: DataLoader behavior configuration
        - deepspeed_plugin: DeepSpeed configuration plugin (single or per-model dict)
        - fsdp_plugin: FSDP configuration plugin
        - megatron_lm_plugin: Megatron-LM configuration plugin
        - rng_types: Random number generator types to synchronize
        - log_with: Experiment tracking backends to use
        - project_dir: Directory for project outputs
        - project_config: Project and logging configuration
        - gradient_accumulation_plugin: Gradient accumulation configuration
        - step_scheduler_with_optimizer: Whether to step the scheduler with the optimizer
        - kwargs_handlers: Additional configuration handlers
        - dynamo_backend: Backend for torch.compile optimization
        - dynamo_plugin: Torch Dynamo configuration plugin
        - parallelism_config: Parallelism configuration object
        """
```
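
A minimal construction sketch; all arguments have defaults, so `Accelerator()` with no arguments is also valid. The specific options shown here are illustrative choices, not requirements:

```python
from accelerate import Accelerator

# A couple of commonly used options; the full argument list is documented above.
accelerator = Accelerator(
    mixed_precision="bf16",        # or "no", "fp16", "fp8"
    gradient_accumulation_steps=2,
)

# The device detected for this process (e.g. cuda:0, or cpu if no GPU is found)
print(accelerator.device)
```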

### Training Preparation

Methods for preparing models, optimizers, and data loaders for distributed training.

```python { .api }
def prepare(self, *args):
    """
    Prepare models, optimizers, dataloaders for distributed training.

    Automatically wraps objects for the current distributed setup and
    applies mixed precision, device placement, and other configurations.

    Parameters:
    - *args: Models, optimizers, dataloaders, schedulers to prepare

    Returns:
        Tuple of prepared objects in same order as input
    """

def prepare_model(self, model: torch.nn.Module, device_placement: bool | None = None):
    """
    Prepare a single model for distributed training.

    Parameters:
    - model: PyTorch model to prepare
    - device_placement: Override default device placement behavior

    Returns:
        Prepared model wrapped for distributed training
    """

def prepare_optimizer(self, optimizer: torch.optim.Optimizer):
    """
    Prepare optimizer for distributed training.

    Parameters:
    - optimizer: PyTorch optimizer to prepare

    Returns:
        Wrapped optimizer for distributed training
    """

def prepare_data_loader(
    self,
    data_loader: torch.utils.data.DataLoader,
    device_placement: bool | None = None
):
    """
    Prepare DataLoader for distributed training.

    Parameters:
    - data_loader: PyTorch DataLoader to prepare
    - device_placement: Override default device placement

    Returns:
        DataLoader configured for distributed training
    """

def prepare_scheduler(self, scheduler):
    """
    Prepare learning rate scheduler for distributed training.

    Parameters:
    - scheduler: PyTorch scheduler to prepare

    Returns:
        Wrapped scheduler for distributed training
    """
```
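
A short sketch of preparing a full set of training objects with prepare(). The model, dataset, and scheduler below are placeholders chosen for illustration, not part of the API above:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Placeholder training objects
model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
dataloader = DataLoader(
    TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))),
    batch_size=8,
)

# prepare() returns the wrapped objects in the same order they were passed in
model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)
```

The per-object methods (prepare_model, prepare_optimizer, prepare_data_loader, prepare_scheduler) cover the same ground when finer control over preparation order is needed.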

### Training Operations

Core training operations including backward pass, gradient clipping, and model unwrapping.

```python { .api }
def backward(self, loss: torch.Tensor, **kwargs):
    """
    Perform backward pass with automatic mixed precision scaling.

    Parameters:
    - loss: Loss tensor to compute gradients from
    - **kwargs: Additional arguments passed to loss.backward()
    """

def clip_grad_norm_(
    self,
    parameters,
    max_norm: float,
    norm_type: float = 2.0
):
    """
    Clip gradient norm across all processes.

    Parameters:
    - parameters: Model parameters or parameter groups
    - max_norm: Maximum norm of gradients
    - norm_type: Type of norm to compute (default: 2.0)

    Returns:
        Total norm of parameters (viewed as single vector)
    """

def clip_grad_value_(self, parameters, clip_value: float):
    """
    Clip gradient values to specified range.

    Parameters:
    - parameters: Model parameters to clip
    - clip_value: Maximum absolute value for gradients
    """

def unwrap_model(self, model: torch.nn.Module, keep_fp32_wrapper: bool = True):
    """
    Extract original model from distributed training wrappers.

    Parameters:
    - model: Wrapped model from prepare()
    - keep_fp32_wrapper: Whether to keep mixed precision wrapper

    Returns:
        Original unwrapped model
    """
```
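
A sketch of a training step combining backward(), clip_grad_norm_(), and unwrap_model(). The model, data, and loss function are illustrative placeholders:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,))),
    batch_size=8,
)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

loss_fn = nn.CrossEntropyLoss()
for inputs, labels in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    accelerator.backward(loss)                             # handles loss scaling under mixed precision
    accelerator.clip_grad_norm_(model.parameters(), 1.0)   # clipping is consistent across processes
    optimizer.step()

# Strip distributed wrappers before saving the raw state dict
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(model)
    torch.save(unwrapped.state_dict(), "model.pt")
```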

### Distributed Communication

Basic distributed operations for gathering, reducing, and padding tensors across processes.

```python { .api }
def gather(self, tensor: torch.Tensor):
    """
    Gather tensor from all processes.

    Parameters:
    - tensor: Tensor to gather across processes

    Returns:
        Tensor concatenated along the first dimension from all processes
        (available on every process)
    """

def gather_for_metrics(self, input_data):
    """
    Gather data from all processes for metrics computation.

    Automatically drops the duplicated samples that were added to make the
    last batch even across processes, so metrics match the dataset size.

    Parameters:
    - input_data: Data to gather (tensors, lists, dicts)

    Returns:
        Gathered data from all processes
    """

def reduce(self, tensor: torch.Tensor, reduction: str = "mean"):
    """
    Reduce tensor across all processes.

    Parameters:
    - tensor: Tensor to reduce
    - reduction: Reduction operation ("mean", "sum")

    Returns:
        Reduced tensor
    """

def pad_across_processes(self, tensor: torch.Tensor, dim: int = 0, pad_index: int = 0):
    """
    Pad tensor to same size across all processes.

    Parameters:
    - tensor: Tensor to pad
    - dim: Dimension to pad along
    - pad_index: Value to use for padding

    Returns:
        Padded tensor
    """
```
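
A sketch of a distributed evaluation loop using gather_for_metrics() to collect predictions from every process. The model and data loader are placeholders, and accuracy is used as a simple example metric:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Linear(16, 4)
eval_loader = DataLoader(
    TensorDataset(torch.randn(80, 16), torch.randint(0, 4, (80,))),
    batch_size=8,
)
model, eval_loader = accelerator.prepare(model, eval_loader)

model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in eval_loader:
        preds = model(inputs).argmax(dim=-1)
        # Gathers from every process and drops duplicated tail samples
        preds, labels = accelerator.gather_for_metrics((preds, labels))
        all_preds.append(preds)
        all_labels.append(labels)

accuracy = (torch.cat(all_preds) == torch.cat(all_labels)).float().mean()
accelerator.print(f"accuracy: {accuracy:.3f}")
```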

### Context Managers

Context managers for controlling training behavior and process synchronization.

```python { .api }
def accumulate(self, *models):
    """
    Context manager for gradient accumulation.

    Automatically handles gradient synchronization based on the
    gradient_accumulation_steps configuration.

    Parameters:
    - *models: Models to control gradient synchronization for
    """

def no_sync(self, model):
    """
    Context manager to disable gradient synchronization.

    Parameters:
    - model: Prepared model to disable gradient synchronization for
    """

def main_process_first(self):
    """
    Context manager to run code on the main process first.

    Ensures the main process completes before other processes continue.
    Useful for dataset preprocessing, model downloading, etc.
    """

def local_main_process_first(self):
    """
    Context manager to run code on the local main process first.

    Similar to main_process_first, but per node instead of global.
    """

def autocast(self, cache_enabled: bool | None = None):
    """
    Context manager for mixed precision autocast.

    Parameters:
    - cache_enabled: Whether to enable autocast cache

    Returns:
        Autocast context manager configured for current precision
    """
```
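
A brief sketch showing main_process_first() around a one-time preprocessing step and accumulate() controlling gradient synchronization. The dataset construction and training objects are placeholders:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

# One-time work (downloads, tokenization, cache writes) runs on the main
# process before the other ranks reach this point, so cached results can
# be reused rather than recomputed.
with accelerator.main_process_first():
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(dataset, batch_size=8)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

loss_fn = nn.CrossEntropyLoss()
for inputs, labels in loader:
    # Gradients are only synchronized every gradient_accumulation_steps steps
    with accelerator.accumulate(model):
        loss = loss_fn(model(inputs), labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```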

### Process Control and Utilities

Methods for process management, synchronization, and training control.

```python { .api }
def wait_for_everyone(self):
    """
    Synchronization barrier - wait for all processes to reach this point.
    """

def print(self, *args, **kwargs):
    """
    Print only on the main process.

    Parameters:
    - *args: Arguments to print
    - **kwargs: Keyword arguments for print function
    """

def split_between_processes(self, inputs, apply_padding: bool = False):
    """
    Context manager that splits inputs between processes for distributed processing.

    Parameters:
    - inputs: Data to split between processes
    - apply_padding: Whether to pad to equal sizes

    Yields:
        Portion of inputs for the current process
    """

def free_memory(self):
    """
    Free memory by clearing internal caches and calling garbage collection.
    """

def clear(self):
    """
    Reset Accelerator to initial state and free memory.
    """

def skip_first_batches(self, dataloader, num_batches: int):
    """
    Skip the first num_batches in a DataLoader.

    Parameters:
    - dataloader: DataLoader to skip batches from
    - num_batches: Number of batches to skip

    Returns:
        DataLoader starting from the specified batch
    """

def verify_device_map(self, model: torch.nn.Module):
    """
    Verify that the device map is valid for the given model.

    Parameters:
    - model: Model to verify device map for
    """

def lomo_backward(self, loss: torch.Tensor, learning_rate: float):
    """
    Perform LOMO (Low-Memory Optimization) backward pass.

    Parameters:
    - loss: Loss tensor to compute gradients from
    - learning_rate: Learning rate for LOMO optimizer
    """

def set_trigger(self):
    """
    Set the internal trigger flag on the current process.

    Paired with check_trigger() to signal a condition (for example,
    early stopping) from one process to all others.
    """

def check_trigger(self):
    """
    Check whether the trigger flag was set on any process, and reset it.

    Returns:
        bool: Whether the trigger was set on any process
    """
```
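
A sketch of the process-control utilities: splitting a workload across ranks, synchronizing, and printing once. The input list is an arbitrary placeholder:

```python
from accelerate import Accelerator

accelerator = Accelerator()

items = [f"sample-{i}" for i in range(10)]

# Each process receives its own slice of the inputs
with accelerator.split_between_processes(items) as local_items:
    results = [item.upper() for item in local_items]

# Make sure every process has finished its slice before continuing
accelerator.wait_for_everyone()

# Printed once, on the main process only
accelerator.print(f"processed {len(items)} items across {accelerator.num_processes} processes")
```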

### Properties

Key properties providing information about the training environment.

```python { .api }
@property
def device(self) -> torch.device:
    """Current device for this process."""

@property
def state(self) -> PartialState:
    """Access to the underlying PartialState."""

@property
def is_main_process(self) -> bool:
    """Whether this is the main process (rank 0)."""

@property
def is_local_main_process(self) -> bool:
    """Whether this is the local main process on this node."""

@property
def process_index(self) -> int:
    """Global process index/rank."""

@property
def local_process_index(self) -> int:
    """Local process index on this node."""

@property
def num_processes(self) -> int:
    """Total number of processes."""

@property
def distributed_type(self) -> DistributedType:
    """Type of distributed training backend being used."""

@property
def mixed_precision(self) -> str:
    """Mixed precision mode being used."""

@property
def use_distributed(self) -> bool:
    """Whether distributed training is being used."""

@property
def should_save_model(self) -> bool:
    """Whether this process should save the model."""

@property
def tensor_parallel_rank(self) -> int:
    """Tensor parallelism rank for this process."""

@property
def pipeline_parallel_rank(self) -> int:
    """Pipeline parallelism rank for this process."""

@property
def context_parallel_rank(self) -> int:
    """Context parallelism rank for this process."""

@property
def data_parallel_rank(self) -> int:
    """Data parallelism rank for this process."""

@property
def fp8_backend(self) -> str | None:
    """FP8 backend being used."""

@property
def is_fsdp2(self) -> bool:
    """Whether FSDP2 is being used."""
```
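
A small sketch of branching on these properties, for example to restrict logging or checkpoint writes to a single rank:

```python
from accelerate import Accelerator

accelerator = Accelerator()

# Rank-aware logging: every process knows its own index and the world size
print(f"rank {accelerator.process_index} / {accelerator.num_processes} on {accelerator.device}")

# Restrict side effects (logging, checkpointing) to one process
if accelerator.is_main_process:
    print(f"backend: {accelerator.distributed_type}, precision: {accelerator.mixed_precision}")
```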

## Usage Examples

### Basic Training Setup

```python
from accelerate import Accelerator
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Initialize with mixed precision and gradient accumulation
accelerator = Accelerator(
    mixed_precision="fp16",
    gradient_accumulation_steps=4
)

# Create model, optimizer, loss function, and data
model = nn.Linear(784, 10)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
dataset = TensorDataset(torch.randn(1024, 784), torch.randint(0, 10, (1024,)))
dataloader = DataLoader(dataset, batch_size=32)

# Prepare for distributed training
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Training loop with gradient accumulation
for inputs, labels in dataloader:
    with accelerator.accumulate(model):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

### Advanced Configuration

```python
from accelerate import Accelerator, DataLoaderConfiguration, ProjectConfiguration

# Advanced configuration
dataloader_config = DataLoaderConfiguration(
    split_batches=True,
    dispatch_batches=False
)

project_config = ProjectConfiguration(
    project_dir="./experiments",
    automatic_checkpoint_naming=True,
    total_limit=5
)

accelerator = Accelerator(
    device_placement=True,
    mixed_precision="bf16",
    gradient_accumulation_steps=8,
    dataloader_config=dataloader_config,
    project_config=project_config
)
```