0
# Utilities and State Management
1
2
Essential utility functions for managing PEFT model state, loading/saving adapters, preparing models for training, and handling various integration scenarios. These functions provide the foundational operations for PEFT workflows.
3
4
## Capabilities
5
6
### State Dictionary Management
7
8
Functions for extracting, setting, and managing PEFT model state dictionaries.
9
10
```python { .api }
11
def get_peft_model_state_dict(
12
model,
13
state_dict: Optional[dict] = None,
14
adapter_name: str = "default"
15
) -> dict:
16
"""
17
Get the state dictionary of PEFT model parameters.
18
19
Args:
20
model: PEFT model instance
21
state_dict: Optional state dict to filter, if None uses model.state_dict()
22
adapter_name: Name of the adapter to get state dict for
23
24
Returns:
25
Dictionary containing only PEFT parameters
26
"""
27
28
def set_peft_model_state_dict(
29
model,
30
peft_model_state_dict: dict,
31
adapter_name: str = "default"
32
):
33
"""
34
Set the state dictionary of PEFT model parameters.
35
36
Args:
37
model: PEFT model instance
38
peft_model_state_dict: State dictionary containing PEFT parameters
39
adapter_name: Name of the adapter to set state dict for
40
"""
41
42
def load_peft_weights(model_id: str, device: Optional[str] = None) -> dict:
43
"""
44
Load PEFT weights from a model identifier or path.
45
46
Args:
47
model_id: Model identifier or local path
48
device: Device to load weights on
49
50
Returns:
51
Dictionary containing loaded PEFT weights
52
"""
53
```
54
55
### Model Preparation and Training Utilities
56
57
Functions for preparing models for efficient training, especially with quantization.
58
59
```python { .api }
60
def prepare_model_for_kbit_training(
61
model,
62
use_gradient_checkpointing: bool = True,
63
gradient_checkpointing_kwargs: Optional[dict] = None
64
):
65
"""
66
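Prepare a quantized model for k-bit training: freeze the base model weights, upcast the remaining half-precision parameters (e.g. layer norms) to float32, and, when gradient checkpointing is requested, enable it and make the input embeddings require gradients.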
67
68
Args:
69
model: Model to prepare for training
70
use_gradient_checkpointing: Whether to enable gradient checkpointing
71
gradient_checkpointing_kwargs: Additional arguments for gradient checkpointing
72
73
Returns:
74
Prepared model ready for k-bit training
75
"""
76
77
def cast_mixed_precision_params(
78
model,
79
dtype: torch.dtype = torch.float16
80
):
81
"""
82
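Cast all non-trainable parameters to the given dtype (torch.float16 or torch.bfloat16); trainable parameters are kept in full precision (float32) for training stability.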
83
84
Args:
85
model: Model to cast parameters for
86
dtype: Target dtype for parameters
87
"""
88
```
89
90
### Configuration and Mapping Utilities
91
92
Functions for working with PEFT configurations and model mappings.
93
94
```python { .api }
95
def get_peft_config(config_dict: dict) -> PeftConfig:
96
"""
97
Get PEFT configuration from dictionary.
98
99
Args:
100
config_dict: Dictionary containing configuration parameters
101
102
Returns:
103
Appropriate PeftConfig instance
104
"""
105
106
def inject_adapter_in_model(
107
peft_config: PeftConfig,
108
model,
109
adapter_name: str = "default"
110
):
111
"""
112
Inject adapter into model based on PEFT configuration.
113
114
Args:
115
peft_config: PEFT configuration
116
model: Base model to inject adapter into
117
adapter_name: Name of the adapter
118
"""
119
```
120
121
### Preprocessing and Postprocessing
122
123
Utility functions for data preprocessing and model-specific postprocessing.
124
125
```python { .api }
126
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
127
"""
128
Shift input tokens to the right for sequence-to-sequence training.
129
130
Args:
131
input_ids: Input token IDs
132
pad_token_id: Padding token ID
133
decoder_start_token_id: Decoder start token ID
134
135
Returns:
136
Shifted token IDs
137
"""
138
139
def bloom_model_postprocess_past_key_value(past_key_values):
"""
Postprocess past key values for BLOOM models.

Args:
past_key_values: Past key value tensors; batch size and sequence length are read from their shape
147
148
Returns:
149
Postprocessed past key values
150
"""
151
```
152
153
### Integration Utilities
154
155
Functions for integrating with various frameworks and handling device management.
156
157
```python { .api }
158
def map_cache_to_layer_device_map(model, cache) -> None:
"""
Ensure that the key and value cache of the model are on the same device as their corresponding layers.

Args:
model: Model whose device map (`hf_device_map`) determines where each layer lives
cache: Cache object whose key/value tensors are moved in place

Returns:
None; the cache is updated in place
173
"""
174
```
175
176
### Target Module Mappings
177
178
Predefined mappings of model architectures to commonly used target modules for different PEFT methods.
179
180
```python { .api }
181
# LoRA target modules for different model architectures
182
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: dict = {
183
"t5": ["q", "v"],
184
"mt5": ["q", "v"],
185
"bart": ["q_proj", "v_proj"],
186
"gpt2": ["c_attn"],
187
"bloom": ["query_key_value"],
188
"blip-2": ["q", "v", "q_proj", "v_proj"],
189
"opt": ["q_proj", "v_proj"],
190
"gptj": ["q_proj", "v_proj"],
191
"gpt_neox": ["query_key_value"],
192
"gpt_neo": ["q_proj", "v_proj"],
193
"bert": ["query", "value"],
194
"roberta": ["query", "value"],
195
"xlm-roberta": ["query", "value"],
196
"electra": ["query", "value"],
197
"deberta-v2": ["query_proj", "value_proj"],
198
"deberta": ["in_proj"],
199
"layoutlm": ["query", "value"],
200
"llama": ["q_proj", "v_proj"],
201
"chatglm": ["query_key_value"],
202
"gpt_bigcode": ["c_attn"],
203
"mpt": ["Wqkv"],
204
"RefinedWebModel": ["query_key_value"],
205
"RefinedWeb": ["query_key_value"],
206
"falcon": ["query_key_value"],
207
"btlm": ["c_proj", "c_attn"],
208
"codegen": ["qkv_proj"],
209
"mistral": ["q_proj", "v_proj"],
210
"mixtral": ["q_proj", "v_proj"],
211
"stablelm": ["q_proj", "v_proj"],
212
"phi": ["q_proj", "v_proj", "fc1", "fc2"],
213
"gemma": ["q_proj", "v_proj"],
214
}
215
216
# AdaLoRA target modules
217
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: dict = {
218
"t5": ["q", "v"],
219
"mt5": ["q", "v"],
220
"bart": ["q_proj", "v_proj"],
221
"gpt2": ["c_attn"],
222
"bloom": ["query_key_value"],
223
"opt": ["q_proj", "v_proj"],
224
"gptj": ["q_proj", "v_proj"],
225
"gpt_neox": ["query_key_value"],
226
"gpt_neo": ["q_proj", "v_proj"],
227
"llama": ["q_proj", "v_proj"],
228
"bert": ["query", "value"],
229
"roberta": ["query", "value"],
230
}
231
232
# IA3 target modules and feedforward modules
233
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING: dict = {
234
"t5": ["k", "v", "wo"],
235
"mt5": ["k", "v", "wo"],
236
"gpt2": ["c_attn", "mlp.c_proj"],
237
"bloom": ["query_key_value", "mlp.dense_4h_to_h"],
238
"opt": ["k_proj", "v_proj", "fc2"],
239
"gptj": ["k_proj", "v_proj", "fc_out"],
240
"gpt_neox": ["query_key_value", "dense_4h_to_h"],
241
"gpt_neo": ["k_proj", "v_proj", "c_proj"],
242
"bart": ["k_proj", "v_proj", "fc2"],
243
"llama": ["k_proj", "v_proj", "down_proj"],
244
}
245
246
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING: dict = {
247
"t5": ["wo"],
248
"mt5": ["wo"],
249
"gpt2": ["mlp.c_proj"],
250
"bloom": ["mlp.dense_4h_to_h"],
251
"opt": ["fc2"],
252
"gptj": ["fc_out"],
253
"gpt_neox": ["dense_4h_to_h"],
254
"gpt_neo": ["c_proj"],
255
"bart": ["fc2"],
256
"llama": ["down_proj"],
257
}
258
```
259
260
### Constants and Configuration Names
261
262
Important constants used throughout the PEFT library.
263
264
```python { .api }
265
CONFIG_NAME: str = "adapter_config.json"
266
WEIGHTS_NAME: str = "adapter_model.bin"
267
SAFETENSORS_WEIGHTS_NAME: str = "adapter_model.safetensors"
268
269
INCLUDE_LINEAR_LAYERS_SHORTHAND: str = "all-linear"
270
```
271
272
## Usage Examples
273
274
### Saving and Loading PEFT State
275
276
```python
277
from peft import get_peft_model_state_dict, set_peft_model_state_dict
278
import torch
279
280
# Get PEFT state dictionary
281
peft_state_dict = get_peft_model_state_dict(peft_model)
282
283
# Save to file
284
torch.save(peft_state_dict, "peft_weights.pt")
285
286
# Load and set state dictionary
287
loaded_state_dict = torch.load("peft_weights.pt")
288
set_peft_model_state_dict(peft_model, loaded_state_dict)
289
```
290
291
### Preparing Model for Quantized Training
292
293
```python
294
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
295
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
296
297
# Load quantized model
298
bnb_config = BitsAndBytesConfig(
299
load_in_4bit=True,
300
bnb_4bit_use_double_quant=True,
301
bnb_4bit_quant_type="nf4",
302
bnb_4bit_compute_dtype=torch.bfloat16
303
)
304
305
model = AutoModelForCausalLM.from_pretrained(
306
"microsoft/DialoGPT-medium",
307
quantization_config=bnb_config,
308
device_map="auto"
309
)
310
311
# Prepare for k-bit training
312
model = prepare_model_for_kbit_training(
313
model,
314
use_gradient_checkpointing=True
315
)
316
317
# Add PEFT adapter
318
peft_config = LoraConfig(
319
r=16,
320
lora_alpha=32,
321
target_modules=["c_attn", "c_proj"],
322
lora_dropout=0.1,
323
bias="none",
324
task_type="CAUSAL_LM"
325
)
326
327
peft_model = get_peft_model(model, peft_config)
328
```
329
330
### Working with Mixed Precision
331
332
```python
333
import torch
from peft import cast_mixed_precision_params
334
335
# Cast frozen (non-trainable) parameters to half precision; trainable adapter weights stay in float32
336
cast_mixed_precision_params(peft_model, torch.float16)
337
338
# Training loop with automatic mixed precision
339
from torch.cuda.amp import autocast, GradScaler
340
341
scaler = GradScaler()
342
343
for batch in dataloader:
344
optimizer.zero_grad()
345
346
with autocast():
347
outputs = peft_model(**batch)
348
loss = outputs.loss
349
350
scaler.scale(loss).backward()
351
scaler.step(optimizer)
352
scaler.update()
353
```
354
355
### Using Target Module Mappings
356
357
```python
358
from peft import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, LoraConfig
359
360
# Get recommended target modules for model architecture
361
model_type = model.config.model_type
362
target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model_type)
363
364
if target_modules:
365
lora_config = LoraConfig(
366
r=16,
367
lora_alpha=32,
368
target_modules=target_modules,
369
task_type="CAUSAL_LM"
370
)
371
else:
372
# Fallback to manual specification
373
lora_config = LoraConfig(
374
r=16,
375
lora_alpha=32,
376
target_modules=["q_proj", "v_proj"], # Manual specification
377
task_type="CAUSAL_LM"
378
)
379
```
380
381
### Handling Sequence-to-Sequence Tasks
382
383
```python
384
from peft import shift_tokens_right
385
386
# Prepare decoder input ids for seq2seq training
387
def prepare_decoder_input_ids_from_labels(labels, pad_token_id, decoder_start_token_id):
388
return shift_tokens_right(labels, pad_token_id, decoder_start_token_id)
389
390
# Example usage in training
391
input_ids = tokenizer("Source text", return_tensors="pt").input_ids
labels = tokenizer("Target text", return_tensors="pt").input_ids
392
decoder_input_ids = prepare_decoder_input_ids_from_labels(
393
labels,
394
tokenizer.pad_token_id,
395
peft_model.config.decoder_start_token_id
396
)
397
398
outputs = peft_model(
399
input_ids=input_ids,
400
decoder_input_ids=decoder_input_ids,
401
labels=labels
402
)
403
loss = outputs.loss
404
```
405
406
### Loading Weights from Hub or Local Path
407
408
```python
409
from peft import load_peft_weights
410
411
# Load from Hugging Face Hub
412
weights = load_peft_weights("username/my-peft-adapter")
413
414
# Load from local path
415
weights = load_peft_weights("./local/peft/adapter")
416
417
# Load with specific device
418
weights = load_peft_weights("username/my-peft-adapter", device="cuda:0")
419
```