# Big Modeling

Device management utilities for handling large models that exceed single-device memory, using CPU/disk offloading, automatic device mapping, and efficient initialization strategies. These utilities enable training and inference with models that would otherwise be impossible to run.

## Capabilities

### CPU Offloading

Functions for offloading model parameters to CPU memory when not in use and moving them to the execution device automatically during the forward pass.

```python { .api }
def cpu_offload(
    model: torch.nn.Module,
    execution_device: torch.device | None = None,
    offload_buffers: bool = False,
    state_dict: dict[str, torch.Tensor] | None = None,
    preload_module_classes: list[str] | None = None
):
    """
    Offload a model to CPU with automatic device management hooks.

    Model parameters are moved to CPU and automatically transferred to the
    execution device during the forward pass, then moved back to CPU.

    Parameters:
    - model: Model to offload to CPU
    - execution_device: Device to use during computation (default: auto-detect)
    - offload_buffers: Whether to also offload buffer tensors
    - state_dict: Optional state dict to use for the model parameters
    - preload_module_classes: Module classes to preload on the execution device

    Returns:
        Model with CPU offloading hooks attached
    """

def cpu_offload_with_hook(
    model: torch.nn.Module,
    execution_device: torch.device | str | int | None = None,
    prev_module_hook: UserCpuOffloadHook | None = None
):
    """
    Advanced CPU offloading with custom hook chaining.

    Provides more control over offloading behavior and allows chaining
    multiple offloading hooks for complex model architectures.

    Parameters:
    - model: Model to offload
    - execution_device: Computation device
    - prev_module_hook: Previous hook in the chain

    Returns:
        Tuple of (model, hook) for hook chaining
    """
```
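
`cpu_offload_with_hook` is useful for pipelines whose sub-models run in sequence: each hook offloads the previous stage when its own module starts running, and the last stage is offloaded manually. Below is a minimal sketch under the assumption that a CUDA device is available; the two `nn.Linear` stages are placeholders for real sub-models.

```python
import torch
import torch.nn as nn
from accelerate import cpu_offload_with_hook

# Two stages of a hypothetical pipeline (placeholders for real sub-models).
stage_1 = nn.Linear(512, 512)
stage_2 = nn.Linear(512, 512)

# Chain the hooks: when stage_2 starts its forward pass, stage_1 is
# automatically moved back to CPU.
stage_1, hook_1 = cpu_offload_with_hook(stage_1, execution_device="cuda:0")
stage_2, hook_2 = cpu_offload_with_hook(stage_2, execution_device="cuda:0", prev_module_hook=hook_1)

x = torch.randn(8, 512)
out = stage_2(stage_1(x))

# The last model in the chain must be offloaded explicitly.
hook_2.offload()
```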

### Disk Offloading

Functions for offloading model parameters to disk storage, for extremely large models that exceed total system memory.

```python { .api }
def disk_offload(
    model: torch.nn.Module,
    offload_dir: str | os.PathLike,
    execution_device: torch.device | str | int | None = None,
    offload_buffers: bool = False
):
    """
    Offload model parameters to disk storage.

    Parameters are saved to disk and loaded on demand during computation.
    Slower than CPU offloading, but enables handling arbitrarily large models.

    Parameters:
    - model: Model to offload to disk
    - offload_dir: Directory to store offloaded parameters
    - execution_device: Device for computation
    - offload_buffers: Whether to offload buffer tensors

    Returns:
        Model with disk offloading hooks
    """
```

### Device Mapping and Dispatch

Functions for automatically distributing model layers across multiple devices based on memory constraints and performance considerations.

```python { .api }
def dispatch_model(
    model: torch.nn.Module,
    device_map: dict[str, torch.device | str | int] | None = None,
    main_device: torch.device | str | int | None = None,
    state_dict: dict[str, torch.Tensor] | None = None,
    strict: bool = False,
    preload_module_classes: list[str] | None = None
):
    """
    Dispatch model layers across multiple devices.

    Automatically places model components on the specified devices and sets up
    hooks for moving tensors between devices during the forward pass.

    Parameters:
    - model: Model to dispatch across devices
    - device_map: Mapping of layer names to devices
    - main_device: Primary device for model execution
    - state_dict: Optional state dict to load
    - strict: Whether to strictly enforce the device mapping
    - preload_module_classes: Module classes to preload

    Returns:
        Model with device dispatch hooks configured
    """

def infer_auto_device_map(
    model: torch.nn.Module,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None,
    dtype: torch.dtype | str | None = None,
    special_dtypes: dict[str, torch.dtype | str] | None = None,
    verbose: bool = False
):
    """
    Automatically infer an optimal device map for a model.

    Analyzes the model architecture and memory constraints to determine
    the best placement of layers across the available devices.

    Parameters:
    - model: Model to analyze
    - max_memory: Maximum memory per device (dict of device_id: memory)
    - no_split_module_classes: Module types that shouldn't be split
    - dtype: Data type for memory calculation
    - special_dtypes: Special data types for specific parameters
    - verbose: Whether to print mapping details

    Returns:
        Dict mapping layer names to optimal devices
    """
```
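
A device map can also be written by hand when you already know where each top-level submodule should live. The sketch below (the module names `block1`/`block2` are placeholders for whatever `model.named_children()` reports) pins one block to GPU 0 and keeps the other block's weights in CPU memory, from where they are streamed to the execution device when that block runs.

```python
import torch
import torch.nn as nn
from accelerate import dispatch_model

model = nn.Sequential()
model.add_module("block1", nn.Linear(1024, 1024))
model.add_module("block2", nn.Linear(1024, 1024))

# Keys are submodule names (as reported by model.named_children());
# values are devices: a GPU index, "cpu", or "disk".
device_map = {"block1": 0, "block2": "cpu"}

model = dispatch_model(model, device_map=device_map)

# Inputs are moved between devices automatically by the dispatch hooks.
out = model(torch.randn(4, 1024))
```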

### Efficient Initialization

Functions for memory-efficient model initialization, particularly useful for large models.

```python { .api }
def init_empty_weights(include_buffers: bool = None):
    """
    Context manager for initializing models with empty tensors on the meta device.

    Creates the model structure without allocating memory for parameters,
    enabling initialization of models larger than the available memory.

    Parameters:
    - include_buffers: Whether to initialize buffers as empty too

    Returns:
        Context manager for empty weight initialization
    """

def init_on_device(
    device: torch.device | str,
    include_buffers: bool = False
):
    """
    Context manager to initialize a model directly on the specified device.

    Avoids creating tensors on CPU first, reducing memory usage and
    improving initialization speed for large models.

    Parameters:
    - device: Target device for initialization
    - include_buffers: Whether to initialize buffers on the device

    Returns:
        Context manager for device-specific initialization
    """
```
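
As a quick illustration of `init_on_device`, here is a minimal sketch (assuming a CUDA device is available): parameters are created directly on the GPU instead of being allocated on CPU and copied over afterwards.

```python
import torch
import torch.nn as nn
from accelerate import init_on_device

# Parameters are allocated directly on cuda:0, skipping the usual
# CPU allocation followed by .to(device).
with init_on_device(torch.device("cuda:0")):
    model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))

print(next(model.parameters()).device)  # cuda:0
```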

### Checkpoint Loading and Management

Functions for loading and managing model checkpoints with device mapping support.

```python { .api }
def load_checkpoint_and_dispatch(
    model: torch.nn.Module,
    checkpoint: str | os.PathLike,
    device_map: dict[str, torch.device | str | int] | None = None,
    max_memory: dict[int | str, int | str] | None = None,
    no_split_module_classes: list[str] | None = None,
    dtype: torch.dtype | str | None = None,
    offload_folder: str | os.PathLike | None = None,
    offload_state_dict: bool = False,
    strict: bool = False
):
    """
    Load a checkpoint and dispatch the model across devices.

    Combines checkpoint loading with automatic device mapping and
    offloading for large models that exceed device memory.

    Parameters:
    - model: Model to load the checkpoint into
    - checkpoint: Path to the checkpoint file
    - device_map: Manual device mapping (optional)
    - max_memory: Memory constraints per device
    - no_split_module_classes: Modules that shouldn't be split
    - dtype: Data type for parameters
    - offload_folder: Directory for offloaded parameters
    - offload_state_dict: Whether to offload the full state dict
    - strict: Strict checkpoint loading

    Returns:
        Model with the checkpoint loaded and device mapping applied
    """

def load_checkpoint_in_model(
    model: torch.nn.Module,
    checkpoint: str | os.PathLike,
    device_map: dict[str, torch.device | str | int] | None = None,
    offload_folder: str | os.PathLike | None = None,
    dtype: torch.dtype | str | None = None,
    offload_state_dict: bool = False,
    offload_buffers: bool = False,
    keep_in_fp32_modules: list[str] | None = None,
    strict: bool = False
):
    """
    Load a checkpoint into a model with advanced offloading options.

    Provides fine-grained control over checkpoint loading with support
    for mixed precision, selective offloading, and memory optimization.

    Parameters:
    - model: Target model
    - checkpoint: Checkpoint path
    - device_map: Device placement mapping
    - offload_folder: Offloading directory
    - dtype: Target data type
    - offload_state_dict: Offload the entire state dict
    - offload_buffers: Offload buffer tensors
    - keep_in_fp32_modules: Modules to keep in FP32
    - strict: Strict loading mode

    Returns:
        List of missing and unexpected keys from the checkpoint
    """
```
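
When you only need to populate weights into an existing model structure, without re-running dispatch, `load_checkpoint_in_model` can be used on its own. A minimal sketch follows; the small `nn.Sequential` stands in for a real model and `"pytorch_model.bin"` is a placeholder path to a saved state dict or a folder of sharded weights.

```python
import torch
import torch.nn as nn
from accelerate import load_checkpoint_in_model

# Stand-in model; in practice this is usually built under init_empty_weights()
# so that no real memory is allocated up front.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

# Load weights directly onto the devices given in device_map.
load_checkpoint_in_model(
    model,
    "pytorch_model.bin",       # placeholder checkpoint path
    device_map={"": "cpu"},    # put the whole model on CPU in this sketch
    dtype=torch.float16,
)
```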

## Usage Examples

### Basic CPU Offloading

```python
from accelerate import cpu_offload
import torch
import torch.nn as nn

# Create a large model
model = nn.Sequential(
    nn.Linear(10000, 10000),
    nn.ReLU(),
    nn.Linear(10000, 1000)
)

# Enable CPU offloading: model parameters live on CPU when not in use
model = cpu_offload(model, execution_device="cuda:0")

# Parameters are moved to the GPU automatically during the forward pass
with torch.no_grad():
    output = model(torch.randn(32, 10000, device="cuda:0"))
```

### Automatic Device Mapping

```python
from accelerate import infer_auto_device_map, dispatch_model
import torch

# `model` is assumed to be an already-instantiated nn.Module
# (for very large models, typically built under init_empty_weights()).

# Define memory constraints (in bytes or as a human-readable string)
max_memory = {
    0: "10GB",      # GPU 0 has 10GB available
    1: "10GB",      # GPU 1 has 10GB available
    "cpu": "50GB"   # 50GB of CPU memory available
}

# Automatically determine an optimal device placement
device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=["LlamaDecoderLayer", "GPT2Block"]
)

# Apply the device mapping
model = dispatch_model(model, device_map=device_map)
```

### Memory-Efficient Initialization

```python
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Initialize the model structure without allocating memory
# (MyLargeModel / config are placeholders for your own model class and config)
with init_empty_weights():
    model = MyLargeModel(config)

# Load the checkpoint with automatic device mapping
model = load_checkpoint_and_dispatch(
    model,
    checkpoint="path/to/checkpoint.bin",
    device_map="auto",
    max_memory={0: "15GB", 1: "15GB", "cpu": "50GB"},
    offload_folder="./offload_weights"
)
```

### Disk Offloading for Extremely Large Models

```python
from accelerate import disk_offload
import tempfile
import torch
import torch.nn as nn

# Stand-in model and input; in practice the model is too large for RAM
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))
input_tensor = torch.randn(8, 4096)

# Create a temporary directory for the offloaded weights
with tempfile.TemporaryDirectory() as temp_dir:
    # Offload the model to disk - enables models larger than total RAM
    model = disk_offload(
        model,
        offload_dir=temp_dir,
        execution_device="cuda:0"
    )

    # Model parameters are loaded from disk on demand
    output = model(input_tensor)
```

### Loading Large Model Checkpoints

```python
import torch
from transformers import AutoConfig, AutoModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# Build the model structure without allocating memory for the weights
config = AutoConfig.from_pretrained("microsoft/DialoGPT-large")
with init_empty_weights(include_buffers=True):
    model = AutoModel.from_config(config, torch_dtype=torch.float16)

# Load and dispatch with automatic device mapping
model = load_checkpoint_and_dispatch(
    model,
    "path/to/sharded/checkpoint",
    device_map="auto",
    max_memory={0: "12GB", "cpu": "30GB"},
    dtype=torch.float16,
    offload_folder="./offload"
)
```