# HuggingFace Accelerate

HuggingFace Accelerate is a PyTorch library that simplifies distributed and mixed-precision training by abstracting away the boilerplate needed for multi-GPU, TPU, and mixed-precision setups. It is a thin wrapper around PyTorch that lets users run their existing training scripts on any hardware configuration (single GPU, multi-GPU, TPU, or CPU) with minimal code changes, typically just a handful of lines.

## Package Information

- **Package Name**: accelerate
- **Language**: Python
- **Installation**: `pip install accelerate`

## Core Imports

```python
from accelerate import Accelerator
```

For specific functionality:

```python
from accelerate import (
    Accelerator,
    PartialState,
    ParallelismConfig,
    cpu_offload,
    cpu_offload_with_hook,
    disk_offload,
    dispatch_model,
    init_empty_weights,
    init_on_device,
    load_checkpoint_and_dispatch,
    skip_first_batches,
    prepare_pippy,
    debug_launcher,
    notebook_launcher,
    find_executable_batch_size,
    infer_auto_device_map,
    load_checkpoint_in_model,
    synchronize_rng_states
)
```

## Basic Usage

```python
from accelerate import Accelerator
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset

# Initialize the accelerator (handles device placement and mixed precision)
accelerator = Accelerator(mixed_precision="fp16")

# Define model, optimizer, and a toy dataset/dataloader
model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-4)
dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
dataloader = DataLoader(dataset, batch_size=16)

# Prepare for distributed training (this is the key step)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Training loop
for inputs, targets in dataloader:
    with accelerator.accumulate(model):
        outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

# Save model
accelerator.save_model(model, "my_model")
```

## Architecture

Accelerate follows a modular design with these key components:

- **Accelerator**: Main orchestrator class that handles distributed training setup, mixed precision, and device management
- **PartialState**: Singleton containing environment information and process control utilities
- **Plugins**: Modular configuration system for different distributed backends (DeepSpeed, FSDP, etc.)
- **Big Modeling**: Utilities for handling models too large to fit in memory through device mapping and offloading
- **Launchers**: Tools for starting distributed training from different environments (notebooks, scripts, CLI)

This design allows Accelerate to work as a universal training wrapper that adapts to any hardware configuration while keeping the user's training code largely unchanged.

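As a small illustration of the process-control layer, the `PartialState` singleton can be used on its own, without a full `Accelerator`, to branch on the main process and split work across ranks. A minimal sketch; the string inputs are placeholders:

```python
from accelerate import PartialState

state = PartialState()

# Runs on every process; prints only once, on the main process
if state.is_main_process:
    print(f"running on {state.num_processes} process(es), device: {state.device}")

# Each process receives its own slice of the inputs
inputs = ["sample-1", "sample-2", "sample-3", "sample-4"]
with state.split_between_processes(inputs) as shard:
    print(f"rank {state.process_index} handles {shard}")

# Block until all processes reach this point
state.wait_for_everyone()
```
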
## Capabilities

### Core Training

The main Accelerator class and essential training functionality including mixed precision, gradient accumulation, and basic distributed operations.

```python { .api }
class Accelerator:
    def __init__(
        self,
        device_placement: bool = True,
        split_batches: bool = False,
        mixed_precision: str | None = None,
        gradient_accumulation_steps: int = 1,
        cpu: bool = False,
        dataloader_config: DataLoaderConfiguration | None = None,
        deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
        fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
        megatron_lm_plugin: MegatronLMPlugin | None = None,
        rng_types: list[str] | None = None,
        log_with: str | list[str] | None = None,
        project_dir: str | None = None,
        project_config: ProjectConfiguration | None = None,
        gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
        step_scheduler_with_optimizer: bool = True,
        kwargs_handlers: list[KwargsHandler] | None = None,
        dynamo_backend: str | None = None,
        dynamo_plugin: TorchDynamoPlugin | None = None,
        parallelism_config: ParallelismConfig | None = None,
        **kwargs
    ): ...

    def prepare(self, *args): ...
    def backward(self, loss, **kwargs): ...
    def gather(self, tensor): ...
    def save_model(self, model, save_directory: str, **kwargs): ...
```

[Core Training](./core-training.md)

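Beyond the training loop in Basic Usage, `gather` is commonly used during evaluation to collect per-process tensors onto every rank before computing metrics. A minimal sketch, assuming `accelerator`, `model`, and an `eval_dataloader` have already gone through `accelerator.prepare(...)`; with uneven dataset sizes the gathered tensors may include a few duplicated samples from the padded final batch.

```python
import torch

# Assumes `accelerator`, `model`, and `eval_dataloader` were created and
# prepared as in the Basic Usage example above.
model.eval()
all_preds, all_targets = [], []

with torch.no_grad():
    for inputs, targets in eval_dataloader:
        preds = model(inputs)
        # Collect tensors from every process so each rank sees the full set
        all_preds.append(accelerator.gather(preds))
        all_targets.append(accelerator.gather(targets))

mse = torch.nn.functional.mse_loss(torch.cat(all_preds), torch.cat(all_targets))
if accelerator.is_main_process:
    print(f"eval MSE: {mse.item():.4f}")
```
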
### Big Modeling

Device management utilities for handling large models through CPU/disk offloading, device mapping, and efficient initialization strategies.

```python { .api }
def cpu_offload(
    model: torch.nn.Module,
    execution_device: torch.device | str | int | None = None,
    offload_buffers: bool = False,
    state_dict: dict[str, torch.Tensor] | None = None,
    preload_module_classes: list[str] | None = None
): ...

def cpu_offload_with_hook(
    model: torch.nn.Module,
    execution_device: torch.device | str | int | None = None,
    prev_module_hook: UserCpuOffloadHook | None = None
): ...

def disk_offload(
    model: torch.nn.Module,
    offload_dir: str | os.PathLike,
    execution_device: torch.device | str | int | None = None,
    offload_buffers: bool = False
): ...

def dispatch_model(
    model: torch.nn.Module,
    device_map: dict[str, torch.device | str | int] | None = None,
    main_device: torch.device | str | int | None = None,
    state_dict: dict[str, torch.Tensor] | None = None,
    strict: bool = False,
    preload_module_classes: list[str] | None = None
): ...

def init_empty_weights(include_buffers: bool = None): ...

def init_on_device(device: torch.device | str | int, include_buffers: bool = None): ...

def load_checkpoint_and_dispatch(
    model: torch.nn.Module,
    checkpoint: str | os.PathLike,
    device_map: dict[str, torch.device | str | int] | None = None,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None,
    strict: bool = False,
    dtype: torch.dtype | None = None
): ...
```

[Big Modeling](./big-modeling.md)

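A typical big-model workflow first builds the module structure without allocating any weights, then loads a checkpoint and dispatches it across the available devices in one call. A minimal sketch; `TinyModel` and the checkpoint path are placeholders for illustration:

```python
import torch.nn as nn
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Placeholder architecture; any torch.nn.Module works the same way
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

    def forward(self, x):
        return self.net(x)

# Build the model skeleton on the meta device: no memory is allocated for weights
with init_empty_weights():
    model = TinyModel()

# Load real weights and place them on GPU/CPU/disk according to an inferred device map
model = load_checkpoint_and_dispatch(
    model,
    checkpoint="path/to/checkpoint",  # placeholder: saved state dict or sharded checkpoint folder
    device_map="auto",
)
```
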
### Distributed Operations

Low-level distributed communication primitives for gathering, broadcasting, and synchronizing data across processes.

```python { .api }
def broadcast(tensor: torch.Tensor, from_process: int = 0): ...
def gather(tensor: torch.Tensor): ...
def reduce(tensor: torch.Tensor, reduction: str = "mean"): ...
def wait_for_everyone(): ...
def synchronize_rng_states(rng_types: list[str] | None = None): ...
```

[Distributed Operations](./distributed-operations.md)

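The same primitives are also available as methods on `Accelerator`, which is the most common way to call them from a training script. A minimal sketch that averages a per-process value and synchronizes all ranks before printing on the main process:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Some per-process value, e.g. a local loss; must be a tensor on the right device
local_loss = torch.tensor(1.0, device=accelerator.device) * (accelerator.process_index + 1)

# Average the value across all processes (every rank receives the result)
mean_loss = accelerator.reduce(local_loss, reduction="mean")

# Make sure every process has finished before, e.g., writing files on the main process
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    print(f"mean loss across {accelerator.num_processes} processes: {mean_loss.item():.4f}")
```
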
### Configuration and Plugins

Configuration classes and plugins for customizing distributed training behavior, including DeepSpeed, FSDP, and mixed precision settings.

```python { .api }
class DeepSpeedPlugin:
    def __init__(
        self,
        hf_ds_config: dict | str | None = None,
        gradient_accumulation_steps: int | None = None,
        gradient_clipping: float | None = None,
        zero_stage: int | None = None,
        **kwargs
    ): ...

class FullyShardedDataParallelPlugin:
    def __init__(
        self,
        sharding_strategy: int | None = None,
        backward_prefetch: int | None = None,
        mixed_precision_policy: MixedPrecision | None = None,
        **kwargs
    ): ...
```

[Configuration](./configuration.md)

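Plugins are passed to the `Accelerator` constructor, so switching backends does not require changes to the training loop. A hedged sketch enabling DeepSpeed ZeRO stage 2 (this assumes the `deepspeed` package is installed and the script is started with `accelerate launch`); the hyperparameter values are illustrative:

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Illustrative ZeRO stage 2 settings
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,
    gradient_accumulation_steps=4,
    gradient_clipping=1.0,
)

accelerator = Accelerator(
    mixed_precision="bf16",
    deepspeed_plugin=deepspeed_plugin,
)
# model, optimizer, and dataloaders are then prepared exactly as in Basic Usage
```
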
### Utilities

Memory management, checkpointing, model utilities, and various helper functions for training workflows.

```python { .api }
def find_executable_batch_size(function, starting_batch_size: int = 128): ...

def infer_auto_device_map(
    model: torch.nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None
): ...

def load_checkpoint_in_model(
    model: torch.nn.Module,
    checkpoint: str | os.PathLike,
    device_map: dict[str, torch.device | str | int] | None = None
): ...
```

[Utilities](./utilities.md)

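`find_executable_batch_size` is typically applied as a decorator to an inner training function: if the function raises a CUDA out-of-memory error, the batch size is halved and the function is retried. A minimal sketch with a placeholder training body:

```python
from accelerate import Accelerator
from accelerate.utils import find_executable_batch_size

def training_function():
    accelerator = Accelerator()

    @find_executable_batch_size(starting_batch_size=128)
    def inner_training_loop(batch_size):
        nonlocal accelerator          # reuse the outer accelerator across retries
        accelerator.free_memory()     # release stale references before rebuilding
        accelerator.print(f"trying batch_size={batch_size}")
        # ... build dataloaders and model with `batch_size`, prepare them, and train ...

    inner_training_loop()

training_function()
```
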
### CLI Commands

Command-line tools for configuration, launching distributed training, memory estimation, and environment management.

```bash { .api }
accelerate config           # Interactive configuration setup
accelerate launch           # Launch distributed training
accelerate env              # Display environment information
accelerate estimate-memory  # Estimate memory requirements
accelerate test             # Test distributed setup
```

[CLI Commands](./cli-commands.md)

### Data Loading

DataLoader utilities for skipping batches and handling distributed data loading patterns.

```python { .api }
def skip_first_batches(dataloader: torch.utils.data.DataLoader, num_batches: int): ...
```

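`skip_first_batches` is mainly useful when resuming training mid-epoch: after restoring a checkpoint, the batches already consumed in the interrupted epoch are skipped once. A minimal sketch with an illustrative dataset and batch count:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator, skip_first_batches

accelerator = Accelerator()
dataset = TensorDataset(torch.randn(2000, 10), torch.randn(2000, 1))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=16))

batches_already_seen = 50  # illustrative: restored from checkpoint metadata

# Only the resumed epoch skips batches; later epochs iterate over `dataloader` in full
resumed_dataloader = skip_first_batches(dataloader, num_batches=batches_already_seen)
for inputs, targets in resumed_dataloader:
    pass  # training step goes here
```
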
### Launchers

Tools for launching distributed training from different environments including notebooks and debugging scenarios.

```python { .api }
def notebook_launcher(
    function,
    args: tuple = (),
    num_processes: int = None,
    mixed_precision: str = "no",
    use_port: str = "29500"
): ...

def debug_launcher(
    function,
    args: tuple = (),
    num_processes: int = 2
): ...
```

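`notebook_launcher` spawns the given function across multiple processes from inside a running notebook (for example on a multi-GPU machine), while `debug_launcher` runs it on CPU processes for quick testing. A minimal sketch; the training body is a placeholder and `num_processes` should match the available hardware:

```python
from accelerate import Accelerator, notebook_launcher

def training_loop(learning_rate):
    # Each spawned process runs this function and creates its own Accelerator
    accelerator = Accelerator()
    accelerator.print(f"process {accelerator.process_index} training with lr={learning_rate}")
    # ... build model/dataloaders, accelerator.prepare(...), and train as usual ...

# Launch two processes from a notebook cell; positional arguments are passed as a tuple
notebook_launcher(training_loop, args=(1e-4,), num_processes=2)
```
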
### Inference

Pipeline parallelism utilities for large model inference.

```python { .api }
def prepare_pippy(
    model: torch.nn.Module,
    split_points: str | list[str] | None = None,
    no_split_module_classes: list[str] | None = None
): ...
```

## Types

```python { .api }
from enum import Enum

class PartialState:
    """Singleton class containing distributed training state."""
    device: torch.device
    distributed_type: DistributedType
    local_process_index: int
    process_index: int
    num_processes: int
    is_main_process: bool
    is_local_main_process: bool

    def wait_for_everyone(self): ...
    def split_between_processes(self, inputs, apply_padding: bool = False): ...

class DataLoaderConfiguration:
    """Configuration for DataLoader behavior in distributed training."""
    split_batches: bool = False
    dispatch_batches: bool | None = None
    even_batches: bool = True
    use_seedable_sampler: bool = False

class ProjectConfiguration:
    """Configuration for project output and logging."""
    project_dir: str = "."
    logging_dir: str | None = None
    automatic_checkpoint_naming: bool = False
    total_limit: int | None = None
    iteration_checkpoints: bool = False
    save_every_n_steps: int | None = None

class DistributedType(Enum):
    """Types of distributed training backends."""
    NO = "NO"
    MULTI_CPU = "MULTI_CPU"
    MULTI_GPU = "MULTI_GPU"
    MULTI_MLU = "MULTI_MLU"
    MULTI_NPU = "MULTI_NPU"
    MULTI_XPU = "MULTI_XPU"
    DEEPSPEED = "DEEPSPEED"
    FSDP = "FSDP"

class PrecisionType(Enum):
    """Mixed precision training types."""
    NO = "no"
    FP16 = "fp16"
    BF16 = "bf16"
    FP8 = "fp8"

class LoggerType(Enum):
    """Experiment tracking logger types."""
    TENSORBOARD = "tensorboard"
    WANDB = "wandb"
    COMET_ML = "comet_ml"
    MLFLOW = "mlflow"
    AIM = "aim"
    CLEARML = "clearml"

class GradientAccumulationPlugin:
    """Plugin for gradient accumulation configuration."""
    num_steps: int
    adjust_scheduler: bool = True
    sync_with_dataloader: bool = True

class MegatronLMPlugin:
    """Plugin for Megatron-LM configuration."""
    tp_degree: int = 1
    pp_degree: int = 1
    num_micro_batches: int = 1
    sequence_parallelism: bool = False
    recompute_activations: bool = False
    use_distributed_optimizer: bool = False

class UserCpuOffloadHook:
    """Hook for managing CPU offloading behavior."""
    def offload(self): ...
    def remove(self): ...

class TorchDynamoPlugin:
    """Plugin for PyTorch Dynamo configuration."""
    backend: str = "inductor"
    mode: str | None = None
    fullgraph: bool = False
    dynamic: bool | None = None
    options: dict | None = None

class KwargsHandler:
    """Base class for handling additional configuration arguments."""
    pass
```
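
As an illustration of how the configuration types above are consumed, `DataLoaderConfiguration` and `ProjectConfiguration` are handed to the `Accelerator` constructor. A minimal sketch with illustrative values (the output directory, checkpoint limit, and tensorboard tracker are assumptions):

```python
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration, ProjectConfiguration

# How batches are split and dispatched across processes
dataloader_config = DataLoaderConfiguration(split_batches=False, even_batches=True)

# Where checkpoints and tracker logs are written; values are illustrative
project_config = ProjectConfiguration(
    project_dir="outputs",
    automatic_checkpoint_naming=True,
    total_limit=3,
)

accelerator = Accelerator(
    dataloader_config=dataloader_config,
    project_config=project_config,
    log_with="tensorboard",  # assumes tensorboard is installed
)
```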