# Utilities

Memory management, checkpointing, model utilities, and other helper functions for training workflows. These utilities support efficient training, model management, and system optimization.

## Capabilities

### Memory Management

Functions for optimizing memory usage during training and inference.

```python { .api }
def find_executable_batch_size(
    function: callable = None,
    starting_batch_size: int = 128
):
    """
    Decorator that automatically finds the largest executable batch size for a function.

    Starts at `starting_batch_size` and cuts the batch size in half each time the
    wrapped function raises an out-of-memory error, retrying until it succeeds.
    Useful for maximizing hardware utilization. The wrapped function must take
    the batch size as its first argument.

    Parameters:
    - function: Function to wrap and retry with different batch sizes
    - starting_batch_size: Initial batch size to try

    Returns:
    The wrapped function, which runs with the largest batch size that fits in memory
    """

def release_memory(*objects):
    """
    Release memory from specified objects and trigger garbage collection.

    Parameters:
    - *objects: Objects to delete and release memory from
    """
```
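
As a minimal sketch of `release_memory`, the snippet below frees large objects between training phases; it assumes the `accelerate.utils` import path and a small placeholder model and optimizer.

```python
import torch
from accelerate.utils import release_memory

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters())

# ... training happens here ...

# release_memory deletes the given objects, runs garbage collection, and
# empties the accelerator cache; it returns None placeholders to reassign.
model, optimizer = release_memory(model, optimizer)
```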

### Model Utilities

Functions for model introspection, manipulation, and memory analysis.

```python { .api }
def infer_auto_device_map(
    model: torch.nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None,
    dtype: torch.dtype | str | None = None,
    special_dtypes: dict[str, torch.dtype | str] | None = None,
    verbose: bool = False
):
    """
    Automatically infer an optimal device map for a model.

    Analyzes model size and available memory to determine the best
    placement of layers across devices.

    Parameters:
    - model: Model to analyze
    - max_memory: Maximum memory per device
    - no_split_module_classes: Module classes that shouldn't be split
    - dtype: Data type for memory calculations
    - special_dtypes: Special dtypes for specific parameters
    - verbose: Print detailed mapping information

    Returns:
    Dictionary mapping layer names to devices
    """

def get_balanced_memory(
    model: torch.nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None,
    dtype: torch.dtype | None = None,
    low_zero: bool = False
):
    """
    Calculate a balanced memory distribution for a model across devices.

    Parameters:
    - model: Model to analyze
    - max_memory: Memory constraints per device
    - no_split_module_classes: Modules to keep together
    - dtype: Data type for calculations
    - low_zero: Minimize the memory used on device 0

    Returns:
    Balanced memory allocation across devices
    """

def compute_module_sizes(
    model: torch.nn.Module,
    dtype: torch.dtype | None = None
):
    """
    Compute the memory size of each module in the model.

    Parameters:
    - model: Model to analyze
    - dtype: Data type for size calculations

    Returns:
    Dictionary mapping module names to memory sizes in bytes
    """

def get_max_memory(max_memory: dict[int | str, int | str] | None = None):
    """
    Get the maximum available memory per device.

    Parameters:
    - max_memory: User-specified memory limits

    Returns:
    Dictionary of available memory per device
    """

def has_offloaded_params(model: torch.nn.Module):
    """
    Check whether the model has any offloaded parameters.

    Parameters:
    - model: Model to check

    Returns:
    Boolean indicating presence of offloaded parameters
    """
```
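
A quick sketch of the memory-inspection helpers: it queries per-device capacity and checks for offloaded parameters, assuming the `accelerate.utils` import path and a small placeholder model.

```python
import torch
from accelerate.utils import convert_bytes, get_max_memory, has_offloaded_params

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

# With no argument, get_max_memory reports what each visible device can hold.
for device, capacity in get_max_memory().items():
    print(device, convert_bytes(capacity))

# A freshly built model has no hooks offloading its weights to CPU or disk.
print("Offloaded parameters:", has_offloaded_params(model))
```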

### Checkpointing and State Management

Functions for saving and loading training state and model checkpoints.

```python { .api }
def save_accelerator_state(
    output_dir: str | os.PathLike,
    safe_serialization: bool = True
):
    """
    Save the complete Accelerator training state.

    Saves model, optimizer, scheduler, and RNG states for complete
    training resumption.

    Parameters:
    - output_dir: Directory to save state files
    - safe_serialization: Use safetensors format when possible
    """

def load_accelerator_state(input_dir: str | os.PathLike):
    """
    Load the complete Accelerator training state.

    Parameters:
    - input_dir: Directory containing saved state files
    """

def save_custom_state(
    obj,
    path: str | os.PathLike,
    process_index: int = 0,
    scaler: callable | None = None
):
    """
    Save custom object state with process coordination.

    Parameters:
    - obj: Object to save
    - path: Path to save object
    - process_index: Process responsible for saving
    - scaler: Optional scaling function
    """

def load_custom_state(
    path: str | os.PathLike,
    process_index: int = 0,
    scaler: callable | None = None
):
    """
    Load custom object state.

    Parameters:
    - path: Path to load object from
    - process_index: Process responsible for loading
    - scaler: Optional scaling function

    Returns:
    Loaded object
    """

def load_checkpoint_in_model(
    model: torch.nn.Module,
    checkpoint: str | os.PathLike,
    device_map: dict[str, torch.device | str | int] | None = None,
    offload_folder: str | os.PathLike | None = None,
    dtype: torch.dtype | None = None,
    offload_state_dict: bool = False,
    offload_buffers: bool = False,
    keep_in_fp32_modules: list[str] | None = None,
    strict: bool = False
):
    """
    Load a checkpoint into a model with advanced placement options.

    Parameters:
    - model: Model to load the checkpoint into
    - checkpoint: Path to the checkpoint file
    - device_map: Device placement mapping
    - offload_folder: Directory for offloaded weights
    - dtype: Target data type
    - offload_state_dict: Offload the entire state dict
    - offload_buffers: Offload buffer tensors
    - keep_in_fp32_modules: Modules to keep in FP32
    - strict: Strict checkpoint loading

    Returns:
    Tuple of (missing_keys, unexpected_keys)
    """
```
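
The sketch below saves a tiny model and reloads its weights through `load_checkpoint_in_model` with an explicit device map; the file name and model are illustrative placeholders, and the import path `accelerate.utils` is assumed.

```python
import torch
from accelerate.utils import load_checkpoint_in_model

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 4))
torch.save(model.state_dict(), "tiny_checkpoint.pt")

# Reload the weights, forcing every layer onto the CPU; a device_map whose ""
# key covers the root module places the whole model on that device.
load_checkpoint_in_model(
    model,
    "tiny_checkpoint.pt",
    device_map={"": "cpu"},
    dtype=torch.float32,
)
```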

### Random State Management

Functions for managing random number generation across distributed processes.

```python { .api }
def set_seed(seed: int, device_specific: bool = False):
    """
    Set the random seed across all processes and libraries.

    Sets seeds for PyTorch, NumPy, Python random, and other libraries
    to ensure reproducible results across distributed training.

    Parameters:
    - seed: Random seed value
    - device_specific: Use device-specific seeding for different results per device
    """

def synchronize_rng_states(
    rng_types: list[str] | None = None,
    generator: torch.Generator | None = None
):
    """
    Synchronize random number generator states across processes.

    Parameters:
    - rng_types: Types of RNG to sync ("torch", "cuda", "xla")
    - generator: Specific generator to synchronize
    """

def synchronize_rng_state(
    rng_type: str | None = None,
    generator: torch.Generator | None = None
):
    """
    Synchronize a single RNG state across processes.

    Parameters:
    - rng_type: Type of RNG to synchronize
    - generator: Specific generator to use
    """
```
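
A minimal sketch of the two seeding modes, assuming `set_seed` is importable from `accelerate.utils`:

```python
from accelerate.utils import set_seed

# Same seed everywhere: every process draws identical random numbers.
set_seed(42)

# Per-device seeding: each process offsets the seed by its process index,
# so random augmentations differ across devices while staying reproducible.
set_seed(42, device_specific=True)
```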

### Model Parameter Management

Functions for managing model parameters, tied weights, and device placement.

```python { .api }
def find_tied_parameters(model: torch.nn.Module):
    """
    Find tied (shared) parameters in a model.

    Parameters:
    - model: Model to analyze

    Returns:
    List of parameter groups that share the same tensor
    """

def check_tied_parameters_on_same_device(model: torch.nn.Module):
    """
    Verify that tied parameters are on the same device.

    Parameters:
    - model: Model to check

    Returns:
    Boolean indicating if all tied parameters are properly placed
    """

def retie_parameters(
    model: torch.nn.Module,
    tied_params: list[list[str]]
):
    """
    Re-establish parameter tying after model loading.

    Parameters:
    - model: Model with parameters to retie
    - tied_params: List of parameter groups to tie together
    """

def set_module_tensor_to_device(
    module: torch.nn.Module,
    tensor_name: str,
    device: torch.device | str | int,
    value: torch.Tensor | None = None,
    dtype: torch.dtype | None = None
):
    """
    Set a specific tensor in a module to a device, with an optional value/dtype.

    Parameters:
    - module: Module containing the tensor
    - tensor_name: Name of the tensor to modify
    - device: Target device
    - value: Optional new tensor value
    - dtype: Optional target dtype
    """

def align_module_device(
    module: torch.nn.Module,
    execution_device: torch.device | str | int
):
    """
    Align a module's device with the execution device.

    Parameters:
    - module: Module to align
    - execution_device: Target execution device
    """
```
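
As a small sketch, `set_module_tensor_to_device` can move a single named tensor rather than a whole module; the tiny layer below is a placeholder, and the `accelerate.utils` import path is assumed.

```python
import torch
from accelerate.utils import set_module_tensor_to_device

linear = torch.nn.Linear(16, 16)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Move only the weight tensor of this module; the bias stays where it was.
set_module_tensor_to_device(linear, "weight", device)
print(linear.weight.device, linear.bias.device)
```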

### File I/O and Serialization

General-purpose functions for saving and loading objects with device awareness.

```python { .api }
def save(
    obj,
    path: str | os.PathLike,
    save_on_each_node: bool = False,
    safe_serialization: bool = False
):
    """
    Save an object with distributed-training awareness.

    Parameters:
    - obj: Object to save
    - path: Save path
    - save_on_each_node: Save on each node instead of just the main process
    - safe_serialization: Use safetensors format when possible
    """

def load(
    path: str | os.PathLike,
    map_location: str | torch.device | None = None,
    **kwargs
):
    """
    Load an object with device-mapping support.

    Parameters:
    - path: Path to load from
    - map_location: Device mapping for tensors
    - **kwargs: Additional arguments for loading

    Returns:
    Loaded object
    """

def clean_state_dict_for_safetensors(state_dict: dict):
    """
    Clean a state dict for safetensors serialization.

    Removes incompatible elements and prepares the dict for the safetensors format.

    Parameters:
    - state_dict: State dictionary to clean

    Returns:
    Cleaned state dictionary
    """
```
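
A minimal sketch of device-aware saving, assuming `save` and `clean_state_dict_for_safetensors` are importable from `accelerate.utils` as the listing above suggests; the model and file name are placeholders.

```python
import torch
from accelerate.utils import clean_state_dict_for_safetensors, save

model = torch.nn.Linear(8, 8)

# Shared/tied tensors are removed so the dict is safe for safetensors.
state_dict = clean_state_dict_for_safetensors(model.state_dict())

# Only the main process writes the file unless save_on_each_node=True.
save(state_dict, "model.safetensors", safe_serialization=True)
```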

### Environment and Import Detection

Functions for detecting available libraries and hardware capabilities.

```python { .api }
def is_cuda_available():
    """Check if CUDA is available."""

def is_mps_available():
    """Check if Apple MPS is available."""

def is_xpu_available():
    """Check if Intel XPU is available."""

def is_hpu_available():
    """Check if Habana HPU is available."""

def is_npu_available():
    """Check if NPU is available."""

def is_deepspeed_available():
    """Check if DeepSpeed is available."""

def is_transformers_available():
    """Check if the Transformers library is available."""

def is_datasets_available():
    """Check if the Datasets library is available."""

def is_wandb_available():
    """Check if Weights & Biases is available."""

def is_tensorboard_available():
    """Check if TensorBoard is available."""

def is_comet_ml_available():
    """Check if Comet ML is available."""

def is_mlflow_available():
    """Check if MLflow is available."""

def is_bnb_available():
    """Check if bitsandbytes is available."""

def is_4bit_bnb_available():
    """Check if 4-bit bitsandbytes quantization is available."""

def is_8bit_bnb_available():
    """Check if 8-bit bitsandbytes quantization is available."""

def is_torch_xla_available():
    """Check if Torch XLA is available."""

def is_rich_available():
    """Check if the Rich formatting library is available."""
```
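
One common use of these checks is picking an experiment tracker based on what is installed; a minimal sketch, assuming the helpers are importable from `accelerate.utils`:

```python
from accelerate.utils import is_tensorboard_available, is_wandb_available

# Prefer Weights & Biases, fall back to TensorBoard, else disable tracking;
# the resulting value could be passed to Accelerator(log_with=...).
if is_wandb_available():
    log_with = "wandb"
elif is_tensorboard_available():
    log_with = "tensorboard"
else:
    log_with = None

print(f"Selected tracker: {log_with}")
```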

### System Utilities

General system and process management utilities.

```python { .api }
def wait_for_everyone():
    """
    Global synchronization barrier across all processes.
    """

def extract_model_from_parallel(
    model: torch.nn.Module,
    keep_fp32_wrapper: bool = True
):
    """
    Extract the original model from parallel training wrappers.

    Parameters:
    - model: Wrapped model
    - keep_fp32_wrapper: Whether to keep the mixed precision wrapper

    Returns:
    Unwrapped model
    """

def merge_dicts(dict1: dict, dict2: dict):
    """
    Merge two dictionaries recursively.

    Parameters:
    - dict1: First dictionary
    - dict2: Second dictionary

    Returns:
    Merged dictionary
    """

def get_pretty_name(obj):
    """
    Get a human-readable name for an object.

    Parameters:
    - obj: Object to get a name for

    Returns:
    Pretty string representation
    """

def write_basic_config(
    mixed_precision: str = "no",
    save_location: str = "default"
):
    """
    Write a basic Accelerate configuration file.

    Parameters:
    - mixed_precision: Mixed precision mode
    - save_location: Where to save the config ("default" or a custom path)
    """

def convert_bytes(size_bytes: int):
    """
    Convert bytes to a human-readable format.

    Parameters:
    - size_bytes: Size in bytes

    Returns:
    Human-readable size string (e.g., "1.5 GB")
    """
```
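
The sketch below shows the usual "synchronize, unwrap, then save" pattern built from these utilities; the tiny model and output path are placeholders, and the `accelerate.utils` import path is assumed.

```python
import torch
from accelerate import Accelerator
from accelerate.utils import extract_model_from_parallel, wait_for_everyone

accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(8, 8))

# Make sure every process has finished its work before saving.
wait_for_everyone()

# Strip DDP / mixed-precision wrappers to get back the plain nn.Module.
unwrapped = extract_model_from_parallel(model)
if accelerator.is_main_process:
    torch.save(unwrapped.state_dict(), "model.pt")
```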

## Usage Examples

### Automatic Batch Size Finding

```python
from accelerate import find_executable_batch_size
import torch

@find_executable_batch_size(starting_batch_size=128)
def training_function(batch_size):
    # Your training code here
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())

    # Simulate training steps
    for _ in range(10):
        batch = torch.randn(batch_size, 784)
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return batch_size

# The decorated function is retried with a halved batch size whenever it runs
# out of memory, so this call settles on the largest batch size that fits.
optimal_batch_size = training_function()
print(f"Optimal batch size: {optimal_batch_size}")
```

### Model Memory Analysis

```python
import torch
from accelerate import (
    compute_module_sizes,
    get_balanced_memory,
    infer_auto_device_map
)

# `model` is assumed to be an already-instantiated torch.nn.Module.
# Analyze model memory usage
module_sizes = compute_module_sizes(model, dtype=torch.float16)
print("Memory usage per module:")
for name, size in module_sizes.items():
    print(f"{name}: {size / 1024**3:.2f} GB")

# Get a balanced memory allocation (GPUs are keyed by integer device index)
max_memory = {0: "10GB", 1: "10GB", "cpu": "30GB"}
balanced_memory = get_balanced_memory(
    model,
    max_memory=max_memory,
    no_split_module_classes=["LlamaDecoderLayer"]
)

# Infer the optimal device mapping
device_map = infer_auto_device_map(
    model,
    max_memory=balanced_memory,
    no_split_module_classes=["LlamaDecoderLayer"],
    verbose=True
)
```

### Advanced Checkpointing

```python
from accelerate import (
    Accelerator,
    save_accelerator_state,
    load_accelerator_state,
    save_custom_state,
    load_custom_state
)

# Save the complete training state
accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# After some training...
save_accelerator_state("./checkpoint-1000", safe_serialization=True)

# Save custom objects
training_metadata = {
    "epoch": 5,
    "best_loss": 0.1234,
    "learning_rates": [0.001, 0.0005, 0.0001]
}
save_custom_state(training_metadata, "./checkpoint-1000/metadata.pkl")

# Later, load everything back
load_accelerator_state("./checkpoint-1000")
metadata = load_custom_state("./checkpoint-1000/metadata.pkl")
```

### Random State Management

```python
from accelerate import set_seed, synchronize_rng_states

# Set a reproducible seed across all processes
set_seed(42, device_specific=False)

# Synchronize RNG states for consistency
synchronize_rng_states(["torch", "cuda"])

# Training with consistent randomness
for epoch in range(num_epochs):
    # All processes will generate the same random augmentations
    for batch in dataloader:
        augmented_batch = apply_random_augmentation(batch)
        # ... training code
```

### Parameter Management

```python
from accelerate import (
    find_tied_parameters,
    check_tied_parameters_on_same_device,
    retie_parameters
)

# `model` is assumed to be an already-instantiated torch.nn.Module.
# Find tied parameters in the model
tied_params = find_tied_parameters(model)
print("Tied parameter groups:", tied_params)

# Check whether tied parameters are properly placed
if not check_tied_parameters_on_same_device(model):
    print("Warning: Tied parameters are not on the same device!")

# Re-tie parameters after loading from a checkpoint
retie_parameters(model, tied_params)
```
639
640
### System Integration
641
642
```python
643
from accelerate import (
644
is_cuda_available,
645
is_deepspeed_available,
646
write_basic_config,
647
convert_bytes
648
)
649
650
# Check system capabilities
651
print(f"CUDA available: {is_cuda_available()}")
652
print(f"DeepSpeed available: {is_deepspeed_available()}")
653
654
# Create basic configuration
655
if is_cuda_available():
656
write_basic_config(mixed_precision="fp16")
657
else:
658
write_basic_config(mixed_precision="no")
659
660
# Memory usage reporting
661
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
662
print(f"Model size: {convert_bytes(model_size)}")
663
```