# Configuration and Plugins

Configuration classes and plugins for customizing distributed training behavior, including DeepSpeed integration, FSDP configuration, mixed precision settings, and other advanced training optimizations.

## Capabilities

### Core Configuration Classes

Base configuration objects for controlling distributed training behavior.

```python { .api }
class DataLoaderConfiguration:
    """
    Configuration for DataLoader behavior in distributed training.

    Controls how data is distributed and processed across multiple processes.
    """

    def __init__(
        self,
        split_batches: bool = False,
        dispatch_batches: bool | None = None,
        even_batches: bool = True,
        use_seedable_sampler: bool = False,
        use_configured_sampler: bool = False,
        non_blocking: bool = False,
        gradient_accumulation_kwargs: dict | None = None
    ):
        """
        Initialize DataLoader configuration.

        Parameters:
        - split_batches: Split each batch yielded by the dataloader across processes, so the overall batch size is independent of the number of processes (otherwise each process receives a full batch)
        - dispatch_batches: Iterate the dataloader only on the main process and broadcast the batches to the other processes
        - even_batches: Ensure all processes receive the same number of batches by duplicating samples when the dataset is not evenly divisible
        - use_seedable_sampler: Use a fully seedable random sampler for reproducibility across runs
        - use_configured_sampler: Use custom sampler configuration
        - non_blocking: Use non-blocking (asynchronous) host-to-device transfers when moving batches
        - gradient_accumulation_kwargs: Additional gradient accumulation settings
        """

class ProjectConfiguration:
    """
    Configuration for project output directories and logging behavior.
    """

    def __init__(
        self,
        project_dir: str = ".",
        logging_dir: str | None = None,
        automatic_checkpoint_naming: bool = False,
        total_limit: int | None = None,
        iteration_checkpoints: bool = False,
        save_on_each_node: bool = False
    ):
        """
        Initialize project configuration.

        Parameters:
        - project_dir: Root directory for project outputs
        - logging_dir: Directory for log files (relative to project_dir)
        - automatic_checkpoint_naming: Auto-generate checkpoint names
        - total_limit: Maximum number of checkpoints to keep
        - iteration_checkpoints: Save checkpoints by iteration number
        - save_on_each_node: Save checkpoints on every node
        """

class GradientAccumulationPlugin:
    """
    Plugin for configuring gradient accumulation behavior.
    """

    def __init__(
        self,
        num_steps: int | None = None,
        adjust_scheduler: bool = True,
        sync_with_dataloader: bool = True
    ):
        """
        Initialize gradient accumulation plugin.

        Parameters:
        - num_steps: Number of steps over which to accumulate gradients before an optimizer step
        - adjust_scheduler: Adjust learning-rate scheduler stepping to account for accumulation (set True if the scheduler was not already tuned for gradient accumulation)
        - sync_with_dataloader: Also synchronize gradients at the end of the dataloader, even if it does not line up with a full accumulation window
        """
```
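
A minimal sketch of what `split_batches` changes in practice, using a throwaway in-memory dataset (the dataset and batch size are placeholders chosen only for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator, DataLoaderConfiguration

# Placeholder dataset: 1024 samples of 16 features each.
dataset = TensorDataset(torch.randn(1024, 16))
dataloader = DataLoader(dataset, batch_size=64)

accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(split_batches=True)
)
dataloader = accelerator.prepare(dataloader)

# With split_batches=True, each of N processes receives 64 // N samples per
# step and the global batch size stays 64. With the default
# split_batches=False, each process receives 64 samples and the effective
# global batch size grows to 64 * N.
for (features,) in dataloader:
    pass  # training step goes here
```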

### DeepSpeed Plugin

Configuration for DeepSpeed distributed training integration.

```python { .api }
class DeepSpeedPlugin:
    """
    Plugin for DeepSpeed distributed training configuration.

    Provides integration with Microsoft DeepSpeed for memory-efficient
    training with ZeRO optimizer states, gradient partitioning, and
    parameter offloading.
    """

    def __init__(
        self,
        hf_ds_config: dict | str | None = None,
        gradient_accumulation_steps: int | None = None,
        gradient_clipping: float | None = None,
        zero_stage: int | None = None,
        is_train_batch_min: bool = True,
        auto_wrap_policy: bool | None = None,
        offload_optimizer_device: str | None = None,
        offload_param_device: str | None = None,
        offload_optimizer_nvme_path: str | None = None,
        offload_param_nvme_path: str | None = None,
        zero3_init_flag: bool | None = None,
        zero3_save_16bit_model: bool | None = None,
        **kwargs
    ):
        """
        Initialize DeepSpeed plugin configuration.

        Parameters:
        - hf_ds_config: DeepSpeed configuration dict or path to a config file
        - gradient_accumulation_steps: Number of gradient accumulation steps
        - gradient_clipping: Gradient clipping threshold
        - zero_stage: ZeRO optimization stage (0, 1, 2, or 3)
        - is_train_batch_min: When both train and eval dataloaders are prepared, derive DeepSpeed's train batch size from the smaller of the two
        - auto_wrap_policy: Automatic model wrapping policy
        - offload_optimizer_device: Device for optimizer state offloading ("none", "cpu", or "nvme")
        - offload_param_device: Device for parameter offloading ("none", "cpu", or "nvme")
        - offload_optimizer_nvme_path: NVMe path for optimizer offloading
        - offload_param_nvme_path: NVMe path for parameter offloading
        - zero3_init_flag: Use deepspeed.zero.Init so large models are constructed directly in partitioned form (ZeRO-3 only)
        - zero3_save_16bit_model: Consolidate and save model weights in 16-bit precision when using ZeRO-3
        """
```
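
For quick experiments the plugin can also be built purely from keyword arguments, without a full DeepSpeed config dict; a sketch assuming the script is started through a DeepSpeed-enabled launch (for example via `accelerate launch`), with illustrative values rather than a tuned recipe:

```python
from accelerate import Accelerator, DeepSpeedPlugin

# ZeRO-3 with CPU offload, configured from keyword arguments only.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,
    gradient_accumulation_steps=2,
    gradient_clipping=1.0,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
    zero3_init_flag=True,
    zero3_save_16bit_model=True,
)

accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
```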

### FSDP Plugin

Configuration for PyTorch Fully Sharded Data Parallel training.

```python { .api }
class FullyShardedDataParallelPlugin:
    """
    Plugin for PyTorch FSDP (Fully Sharded Data Parallel) configuration.

    Enables memory-efficient training by sharding model parameters,
    gradients, and optimizer states across multiple GPUs.
    """

    def __init__(
        self,
        sharding_strategy: int | None = None,
        backward_prefetch: int | None = None,
        mixed_precision_policy: MixedPrecision | None = None,
        auto_wrap_policy: ModuleWrapPolicy | None = None,
        cpu_offload: CPUOffload | None = None,
        ignored_modules: list[torch.nn.Module] | None = None,
        state_dict_type: str | None = None,
        state_dict_config: dict | None = None,
        optim_state_dict_config: dict | None = None,
        limit_all_gathers: bool = True,
        use_orig_params: bool = True,
        param_init_fn: callable | None = None,
        sync_module_states: bool = True,
        forward_prefetch: bool = False,
        activation_checkpointing: bool = False
    ):
        """
        Initialize FSDP plugin configuration.

        Parameters:
        - sharding_strategy: Parameter sharding strategy (e.g. FULL_SHARD, SHARD_GRAD_OP, NO_SHARD, HYBRID_SHARD)
        - backward_prefetch: Backward pass prefetching strategy (e.g. BACKWARD_PRE, BACKWARD_POST)
        - mixed_precision_policy: Mixed precision configuration
        - auto_wrap_policy: Automatic module wrapping policy
        - cpu_offload: CPU offloading configuration
        - ignored_modules: Modules to exclude from FSDP wrapping
        - state_dict_type: Type of state dict to use (e.g. FULL_STATE_DICT, SHARDED_STATE_DICT)
        - state_dict_config: State dict configuration
        - optim_state_dict_config: Optimizer state dict configuration
        - limit_all_gathers: Limit the number of in-flight all-gather operations to reduce peak memory
        - use_orig_params: Expose the original module parameters (rather than FSDP's flattened parameters) to the optimizer
        - param_init_fn: Custom initialization function for modules created on the meta device
        - sync_module_states: Broadcast module states from rank 0 at initialization so every rank starts from identical weights
        - forward_prefetch: Enable forward pass prefetching of the next all-gather
        - activation_checkpointing: Enable activation checkpointing
        """
```

### Mixed Precision Configuration

Configuration classes for different mixed precision training modes.

```python { .api }
class AutocastKwargs:
    """
    Configuration for PyTorch autocast mixed precision.
    """

    def __init__(
        self,
        enabled: bool = True,
        cache_enabled: bool | None = None
    ):
        """
        Initialize autocast configuration.

        Parameters:
        - enabled: Whether to enable autocast
        - cache_enabled: Whether to enable the autocast weight cache
        """

class GradScalerKwargs:
    """
    Configuration for gradient scaling in mixed precision training.
    """

    def __init__(
        self,
        init_scale: float = 65536.0,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
        enabled: bool = True
    ):
        """
        Initialize gradient scaler configuration.

        Parameters:
        - init_scale: Initial scaling factor
        - growth_factor: Factor the scale is multiplied by after `growth_interval` consecutive steps without overflow
        - backoff_factor: Factor the scale is multiplied by when an overflow is detected
        - growth_interval: Number of consecutive non-overflow steps before the scale is increased
        - enabled: Whether gradient scaling is enabled
        """

class FP8RecipeKwargs:
    """
    Configuration for FP8 (8-bit floating point) training.
    """

    def __init__(
        self,
        backend: str = "TE",
        use_autocast: bool = True,
        fp8_format: str = "HYBRID",
        amax_history_len: int = 1024,
        amax_compute_algo: str = "most_recent"
    ):
        """
        Initialize FP8 training configuration.

        Parameters:
        - backend: FP8 backend to use ("TE" for Transformer Engine)
        - use_autocast: Whether to use autocast with FP8
        - fp8_format: FP8 format specification ("E4M3" or "HYBRID")
        - amax_history_len: Length of the amax history window used for scaling
        - amax_compute_algo: Algorithm for computing amax values ("max" or "most_recent")
        """
```
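
To make the mapping concrete, these kwargs handlers roughly correspond to the arguments of torch's own autocast context and gradient scaler. The sketch below is illustrative only; the device type and dtype shown are assumptions (in practice the Accelerator chooses them from the device and `mixed_precision` setting), and this is not Accelerate's internal code:

```python
import torch

# Roughly what AutocastKwargs configures: the torch.autocast context used
# around forward passes (device_type and dtype are placeholders here).
autocast_ctx = torch.autocast(
    device_type="cuda",
    dtype=torch.float16,
    enabled=True,          # AutocastKwargs.enabled
    cache_enabled=False,   # AutocastKwargs.cache_enabled
)

# Roughly what GradScalerKwargs configures: the gradient scaler used for
# fp16 training.
scaler = torch.cuda.amp.GradScaler(
    init_scale=65536.0,    # GradScalerKwargs.init_scale
    growth_factor=2.0,     # GradScalerKwargs.growth_factor
    backoff_factor=0.5,    # GradScalerKwargs.backoff_factor
    growth_interval=2000,  # GradScalerKwargs.growth_interval
    enabled=True,          # GradScalerKwargs.enabled
)
```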

### Torch Compilation and Optimization

Configuration for PyTorch compilation and optimization features.

```python { .api }
class TorchDynamoPlugin:
    """
    Plugin for PyTorch Dynamo compilation configuration.

    Enables torch.compile optimizations for faster training and inference.
    """

    def __init__(
        self,
        backend: str = "inductor",
        mode: str | None = None,
        fullgraph: bool = False,
        dynamic: bool | None = None,
        options: dict | None = None,
        disable: bool = False
    ):
        """
        Initialize Torch Dynamo plugin.

        Parameters:
        - backend: Compilation backend ("inductor", "aot_eager", etc.)
        - mode: Compilation mode ("default", "reduce-overhead", "max-autotune")
        - fullgraph: Require the whole model to be captured in a single graph (error on graph breaks)
        - dynamic: Enable dynamic-shape compilation instead of specializing on observed shapes
        - options: Additional backend-specific options
        - disable: Whether to disable compilation
        """

class TorchTensorParallelPlugin:
    """
    Plugin for PyTorch tensor parallelism configuration.
    """

    def __init__(
        self,
        tensor_parallel_degree: int = 1,
        parallelize_plan: dict | None = None
    ):
        """
        Initialize tensor parallel plugin.

        Parameters:
        - tensor_parallel_degree: Degree of tensor parallelism (number of devices each layer is split across)
        - parallelize_plan: Custom parallelization plan
        """
```
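
A usage sketch for enabling torch.compile through the plugin; it assumes an Accelerate version whose `Accelerator` accepts a `dynamo_plugin` argument (older releases expose a `dynamo_backend` string argument instead):

```python
from accelerate import Accelerator
from accelerate.utils import TorchDynamoPlugin

dynamo_plugin = TorchDynamoPlugin(
    backend="inductor",
    mode="reduce-overhead",
    fullgraph=False,
)

# Models passed through prepare() are then compiled with torch.compile
# using the configured backend and mode.
accelerator = Accelerator(dynamo_plugin=dynamo_plugin)
```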

### Quantization Configuration

Configuration classes for model quantization techniques.

```python { .api }
class BnbQuantizationConfig:
    """
    Configuration for Bitsandbytes quantization.

    Enables 4-bit and 8-bit quantization for memory-efficient training.
    """

    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        llm_int8_threshold: float = 6.0,
        llm_int8_skip_modules: list[str] | None = None,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        bnb_4bit_compute_dtype: torch.dtype | None = None,
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        bnb_4bit_quant_storage: torch.dtype | None = None
    ):
        """
        Initialize Bitsandbytes quantization configuration.

        Parameters:
        - load_in_8bit: Enable 8-bit quantization
        - load_in_4bit: Enable 4-bit quantization
        - llm_int8_threshold: Outlier threshold for the LLM.int8() mixed-precision decomposition
        - llm_int8_skip_modules: Modules to skip during quantization
        - llm_int8_enable_fp32_cpu_offload: Allow keeping part of the model in FP32 on the CPU while the rest is quantized on GPU
        - llm_int8_has_fp16_weight: Run LLM.int8() with 16-bit main weights (useful for fine-tuning)
        - bnb_4bit_compute_dtype: Compute dtype for 4-bit operations
        - bnb_4bit_quant_type: 4-bit quantization type ("fp4" or "nf4")
        - bnb_4bit_use_double_quant: Enable double (nested) quantization of the quantization constants
        - bnb_4bit_quant_storage: Storage dtype for quantized weights
        """
```

### Process Group Configuration

Configuration for distributed process group initialization.

```python { .api }
class InitProcessGroupKwargs:
    """
    Configuration for distributed process group initialization.
    """

    def __init__(
        self,
        init_method: str | None = None,
        timeout: timedelta = timedelta(seconds=1800),
        backend: str | None = None
    ):
        """
        Initialize process group configuration.

        Parameters:
        - init_method: Method for process group initialization
        - timeout: Timeout for collective operations, as a datetime.timedelta (defaults to 30 minutes)
        - backend: Distributed backend to use (e.g. "nccl", "gloo")
        """

class DistributedDataParallelKwargs:
    """
    Configuration for PyTorch DistributedDataParallel wrapper.
    """

    def __init__(
        self,
        dim: int = 0,
        broadcast_buffers: bool = True,
        bucket_cap_mb: int = 25,
        find_unused_parameters: bool = False,
        check_reduction: bool = False,
        gradient_as_bucket_view: bool = False,
        static_graph: bool = False,
        comm_hook: callable | None = None,
        comm_state_option: str | None = None
    ):
        """
        Initialize DDP configuration.

        Parameters:
        - dim: Dimension along which inputs are scattered
        - broadcast_buffers: Broadcast module buffers from rank 0 at the start of each forward pass
        - bucket_cap_mb: Bucket size for gradient communication (MB)
        - find_unused_parameters: Traverse the autograd graph to find parameters that did not take part in the forward pass (needed when parts of the model are conditionally skipped)
        - check_reduction: Check gradient reduction correctness
        - gradient_as_bucket_view: Let gradients point into the communication buckets to save memory
        - static_graph: Declare the computation graph static across iterations, enabling additional optimizations
        - comm_hook: Custom gradient communication hook
        - comm_state_option: Communication hook state configuration
        """
```
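
Kwargs handlers are passed to the `Accelerator` together through `kwargs_handlers`; a short example that raises the process-group timeout and enables unused-parameter detection in DDP (the timeout value is illustrative):

```python
from datetime import timedelta

from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

# Allow long collective operations (e.g. slow checkpoint loading) and let DDP
# cope with conditionally-used submodules.
pg_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

accelerator = Accelerator(kwargs_handlers=[pg_kwargs, ddp_kwargs])
```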

## Usage Examples

### Basic Configuration Setup

```python
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import GradientAccumulationPlugin, ProjectConfiguration

# Configure data loading behavior
dataloader_config = DataLoaderConfiguration(
    split_batches=True,
    even_batches=True,
    use_seedable_sampler=True
)

# Configure project outputs
project_config = ProjectConfiguration(
    project_dir="./experiments",
    logging_dir="logs",
    automatic_checkpoint_naming=True,
    total_limit=5
)

# Configure gradient accumulation
grad_accumulation = GradientAccumulationPlugin(
    num_steps=4,
    adjust_scheduler=True
)

# Initialize accelerator with configurations
accelerator = Accelerator(
    mixed_precision="fp16",
    dataloader_config=dataloader_config,
    project_config=project_config,
    gradient_accumulation_plugin=grad_accumulation
)
```
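
Inside the training loop, the gradient accumulation plugin is consumed through `accelerator.accumulate()` and the project configuration through `accelerator.save_state()`; a sketch in which `model`, `optimizer`, `scheduler`, and `train_dataloader` are placeholders assumed to be defined elsewhere:

```python
model, optimizer, scheduler, train_dataloader = accelerator.prepare(
    model, optimizer, scheduler, train_dataloader
)

for batch in train_dataloader:
    # Gradients are only synchronized and applied every `num_steps` batches,
    # following the GradientAccumulationPlugin configured above.
    with accelerator.accumulate(model):
        loss = model(**batch).loss
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

# With automatic_checkpoint_naming=True, checkpoints are written under
# project_dir (here ./experiments/checkpoints/) and pruned to total_limit.
accelerator.save_state()
```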

### DeepSpeed Configuration

```python
from accelerate import Accelerator, DeepSpeedPlugin

# Define DeepSpeed configuration
deepspeed_config = {
    "train_batch_size": 16,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1e-4}
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
        "overlap_comm": True,
        "contiguous_gradients": True
    },
    "fp16": {"enabled": True}
}

# Create DeepSpeed plugin
deepspeed_plugin = DeepSpeedPlugin(
    hf_ds_config=deepspeed_config,
    zero_stage=2,
    gradient_accumulation_steps=4,
    gradient_clipping=1.0
)

# Initialize accelerator with DeepSpeed
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
```
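
Since `hf_ds_config` accepts either a dict or a filename, the same configuration can also live in a JSON file and be referenced by path. A small sketch continuing from the example above (the file name is arbitrary):

```python
import json

# Persist the configuration and point the plugin at the file instead.
with open("ds_config.json", "w") as f:
    json.dump(deepspeed_config, f, indent=2)

deepspeed_plugin = DeepSpeedPlugin(
    hf_ds_config="ds_config.json",
    gradient_clipping=1.0
)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
```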

### FSDP Configuration

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch

# Configure FSDP plugin
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    cpu_offload=None,             # Keep parameters on GPU
    mixed_precision_policy=None,  # Follow the Accelerator's mixed_precision setting
    auto_wrap_policy=None,        # Use the default / configured wrapping policy
    limit_all_gathers=True,
    use_orig_params=True,
    sync_module_states=True
)

# Initialize with FSDP
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```
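
For explicit control over which submodules become FSDP units, an auto-wrap policy can be built from torch's wrapping helpers and passed as `auto_wrap_policy`. A sketch reusing the imports from the example above, with `nn.TransformerEncoderLayer` standing in for your model's repeated block class:

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Wrap every nn.TransformerEncoderLayer (substitute your own block class)
# in its own FSDP unit.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=auto_wrap_policy,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```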

### Advanced Mixed Precision Setup

```python
from accelerate import (
    Accelerator,
    AutocastKwargs,
    GradScalerKwargs,
    FP8RecipeKwargs
)

# Configure autocast behavior
autocast_kwargs = AutocastKwargs(
    enabled=True,
    cache_enabled=False
)

# Configure gradient scaling
scaler_kwargs = GradScalerKwargs(
    init_scale=2**16,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=2000
)

# Configure FP8 training (if supported by the hardware and backend)
fp8_kwargs = FP8RecipeKwargs(
    backend="TE",
    use_autocast=True,
    fp8_format="HYBRID"
)

# Initialize with fp16 mixed precision; to train in FP8 instead, pass
# mixed_precision="fp8" and include fp8_kwargs in kwargs_handlers.
accelerator = Accelerator(
    mixed_precision="fp16",
    kwargs_handlers=[autocast_kwargs, scaler_kwargs]
)
```

### Quantization Setup

```python
import torch
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
from transformers import AutoModelForCausalLM

# Configure 4-bit quantization
bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Note: quantization is applied while loading and placing the model
# rather than through Accelerator initialization.
model = AutoModelForCausalLM.from_pretrained("model_name")
model = load_and_quantize_model(
    model,
    bnb_quantization_config=bnb_config,
    device_map="auto"
)
```
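
When the model is loaded through the transformers library, an equivalent route is to let `from_pretrained` quantize during loading using transformers' own `BitsAndBytesConfig`, which mirrors the field names above; a sketch with a placeholder checkpoint name:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same bitsandbytes options, expressed with transformers' config class.
hf_bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "model_name",  # placeholder checkpoint name
    quantization_config=hf_bnb_config,
    device_map="auto",
)
```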