or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-methods.md · auto-classes.md · core-models.md · index.md · lora-methods.md · prompt-learning.md · utilities.md

docs/utilities.md

# Utilities and State Management

Essential utility functions for managing PEFT model state, loading and saving adapters, preparing models for training, and integrating with other libraries. These functions provide the foundational operations for PEFT workflows.

## Capabilities

### State Dictionary Management

Functions for extracting, setting, and loading PEFT model state dictionaries.

```python { .api }

def get_peft_model_state_dict(
    model,
    state_dict: Optional[dict] = None,
    adapter_name: str = "default",
) -> dict:
    """
    Get the state dictionary of PEFT model parameters.

    Args:
        model: PEFT model instance.
        state_dict: Optional state dict to filter; if None, ``model.state_dict()`` is used.
        adapter_name: Name of the adapter to get the state dict for.

    Returns:
        Dictionary containing only the PEFT (adapter) parameters.
    """

def set_peft_model_state_dict(
    model,
    peft_model_state_dict: dict,
    adapter_name: str = "default",
):
    """
    Set the state dictionary of PEFT model parameters.

    Args:
        model: PEFT model instance.
        peft_model_state_dict: State dictionary containing PEFT parameters,
            e.g. as produced by ``get_peft_model_state_dict``.
        adapter_name: Name of the adapter to load the state dict into.
    """

def load_peft_weights(model_id: str, device: Optional[str] = None) -> dict:
    """
    Load PEFT weights from a Hugging Face Hub identifier or a local path.

    Args:
        model_id: Hub repository identifier (e.g. ``"username/my-peft-adapter"``)
            or a local directory containing a saved adapter.
        device: Device to load the weights on (e.g. ``"cpu"``, ``"cuda:0"``).

    Returns:
        Dictionary containing the loaded PEFT weights.
    """

```

### Model Preparation and Training Utilities

Functions for preparing models for efficient training, especially quantized (k-bit) and mixed-precision training.

```python { .api }

def prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing: bool = True,
    gradient_checkpointing_kwargs: Optional[dict] = None,
):
    """
    Prepare a quantized (k-bit) model for training with PEFT adapters.

    All base-model parameters are frozen. Remaining non-quantized
    half-precision (fp16/bf16) parameters, such as layer norms, are upcast to
    float32 for numerical stability. When gradient checkpointing is enabled,
    the input embeddings' outputs are made to require gradients, so gradients
    can flow through the checkpointed (frozen) layers into the adapters.

    Args:
        model: Model to prepare for training, typically loaded with a
            quantization config.
        use_gradient_checkpointing: Whether to enable gradient checkpointing,
            which trades extra compute for lower activation memory.
        gradient_checkpointing_kwargs: Extra arguments passed to the model's
            ``gradient_checkpointing_enable`` (e.g. ``{"use_reentrant": False}``).

    Returns:
        The prepared model, ready for k-bit training.
    """

def cast_mixed_precision_params(
    model,
    dtype: torch.dtype = torch.float16,
):
    """
    Cast parameters for memory-efficient mixed-precision training.

    Non-trainable (frozen) parameters are cast to ``dtype`` to save memory.
    Trainable parameters (the adapter weights) are kept in float32. This keeps
    optimizer updates numerically stable under automatic mixed precision, and
    it lets a gradient scaler unscale their gradients, which it refuses to do
    for fp16 gradients.

    Args:
        model: Model whose parameters are cast in place.
        dtype: Target dtype for the frozen parameters. Match it to the AMP
            dtype in use (``torch.float16`` or ``torch.bfloat16``).
    """

```

### Configuration and Mapping Utilities

Functions for building PEFT configurations and injecting adapters into models.

```python { .api }

def get_peft_config(config_dict: dict) -> "PeftConfig":
    """
    Build a PEFT configuration from a dictionary.

    Args:
        config_dict: Dictionary of configuration parameters. Its ``"peft_type"``
            entry selects the concrete configuration class.

    Returns:
        The matching ``PeftConfig`` instance (e.g. a ``LoraConfig``).
    """

def inject_adapter_in_model(
    peft_config: "PeftConfig",
    model,
    adapter_name: str = "default",
):
    """
    Inject adapter layers into a model based on a PEFT configuration.

    Unlike ``get_peft_model``, the model is modified in place and is not
    wrapped in a ``PeftModel``. Prompt-learning methods are not supported.
    Make sure ``peft_config.target_modules`` names the layers to adapt.

    Args:
        peft_config: PEFT configuration describing the adapter.
        model: Base model to inject the adapter into.
        adapter_name: Name of the adapter.

    Returns:
        The same model, with the adapter layers injected.
    """

```

### Preprocessing and Postprocessing

Utility functions for data preprocessing and model-specific postprocessing.

```python { .api }

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input tokens one position to the right for sequence-to-sequence training.

    The first position of each row is filled with ``decoder_start_token_id``
    and the last token is dropped, which turns labels into decoder inputs.
    Any ``-100`` values (the label value ignored by the loss) are replaced by
    ``pad_token_id``.

    Args:
        input_ids: Input token IDs of shape ``(batch_size, seq_len)``.
        pad_token_id: Padding token ID.
        decoder_start_token_id: Decoder start token ID (model-specific; read it
            from the model config rather than guessing).

    Returns:
        Shifted token IDs with the same shape as ``input_ids``.

    Raises:
        ValueError: If ``pad_token_id`` is None.
    """

def bloom_model_postprocess_past_key_value(past_key_values, batch_size: int, seq_len: int):
    """
    Postprocess past key values for BLOOM models.

    Args:
        past_key_values: Past key/value tensors.
        batch_size: Batch size.
        seq_len: Sequence length.

    Returns:
        Postprocessed past key values.
    """

```

### Integration Utilities

Functions for device management, such as placing caches correctly for distributed or offloaded inference.

```python { .api }

def map_cache_to_layer_device_map(
    cache,
    layer_device_map: dict,
    offload_dir: Optional[str] = None,
):
    """
    Map cache tensors onto the devices given by a per-layer device map, for
    distributed inference.

    Args:
        cache: Cache object to map.
        layer_device_map: Mapping of layers to devices.
        offload_dir: Directory used for offloading tensors.

    Returns:
        The mapped cache object.
    """

```

### Target Module Mappings

Predefined mappings from model architecture (the `model_type` in a Transformers model config) to the modules commonly targeted by each PEFT method.

```python { .api }

# LoRA target modules for different model architectures
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: dict = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "blip-2": ["q", "v", "q_proj", "v_proj"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
    "gpt_neo": ["q_proj", "v_proj"],
    "bert": ["query", "value"],
    "roberta": ["query", "value"],
    "xlm-roberta": ["query", "value"],
    "electra": ["query", "value"],
    "deberta-v2": ["query_proj", "value_proj"],
    "deberta": ["in_proj"],
    "layoutlm": ["query", "value"],
    "llama": ["q_proj", "v_proj"],
    "chatglm": ["query_key_value"],
    "gpt_bigcode": ["c_attn"],
    "mpt": ["Wqkv"],
    "RefinedWebModel": ["query_key_value"],
    "RefinedWeb": ["query_key_value"],
    "falcon": ["query_key_value"],
    "btlm": ["c_proj", "c_attn"],
    "codegen": ["qkv_proj"],
    "mistral": ["q_proj", "v_proj"],
    "mixtral": ["q_proj", "v_proj"],
    "stablelm": ["q_proj", "v_proj"],
    "phi": ["q_proj", "v_proj", "fc1", "fc2"],
    "gemma": ["q_proj", "v_proj"],
}

# AdaLoRA target modules
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: dict = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
    "gpt_neo": ["q_proj", "v_proj"],
    "llama": ["q_proj", "v_proj"],
    "bert": ["query", "value"],
    "roberta": ["query", "value"],
}

# IA3 target modules; each list also contains that model's IA3 feedforward modules
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING: dict = {
    "t5": ["k", "v", "wo"],
    "mt5": ["k", "v", "wo"],
    "gpt2": ["c_attn", "mlp.c_proj"],
    "bloom": ["query_key_value", "mlp.dense_4h_to_h"],
    "opt": ["k_proj", "v_proj", "fc2"],
    "gptj": ["k_proj", "v_proj", "fc_out"],
    "gpt_neox": ["query_key_value", "dense_4h_to_h"],
    "gpt_neo": ["k_proj", "v_proj", "c_proj"],
    "bart": ["k_proj", "v_proj", "fc2"],
    "llama": ["k_proj", "v_proj", "down_proj"],
}

# IA3 feedforward modules; each must also be listed among the model's IA3 target modules
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING: dict = {
    "t5": ["wo"],
    "mt5": ["wo"],
    "gpt2": ["mlp.c_proj"],
    "bloom": ["mlp.dense_4h_to_h"],
    "opt": ["fc2"],
    "gptj": ["fc_out"],
    "gpt_neox": ["dense_4h_to_h"],
    "gpt_neo": ["c_proj"],
    "bart": ["fc2"],
    "llama": ["down_proj"],
}

```

### Constants and Configuration Names

Important constants used throughout the PEFT library, including the default file names of saved adapters.

```python { .api }

CONFIG_NAME: str = "adapter_config.json"
WEIGHTS_NAME: str = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME: str = "adapter_model.safetensors"

# Pass as `target_modules` (e.g. LoraConfig(target_modules="all-linear")) to
# target every linear layer except the model's output layer.
INCLUDE_LINEAR_LAYERS_SHORTHAND: str = "all-linear"

```

## Usage Examples

### Saving and Loading PEFT State

```python

from peft import get_peft_model_state_dict, set_peft_model_state_dict
import torch

# `peft_model` is an existing PeftModel (e.g. returned by `get_peft_model`).

# Get PEFT state dictionary (adapter weights only)
peft_state_dict = get_peft_model_state_dict(peft_model)

# Save to file
torch.save(peft_state_dict, "peft_weights.pt")

# Load and set state dictionary. weights_only=True refuses to unpickle
# arbitrary Python objects, so a tampered file cannot execute code on load.
loaded_state_dict = torch.load("peft_weights.pt", weights_only=True)
set_peft_model_state_dict(peft_model, loaded_state_dict)

```

### Preparing Model for Quantized Training

```python

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

# Load quantized model: 4-bit NF4 weights with double quantization,
# with computation carried out in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/DialoGPT-medium",
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare for k-bit training
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
)

# Add PEFT adapter. DialoGPT uses the GPT-2 architecture, whose
# attention/projection layers are named `c_attn` and `c_proj`.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["c_attn", "c_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, peft_config)

```

### Working with Mixed Precision

```python

import torch
from peft import cast_mixed_precision_params

# Assumes `peft_model`, `optimizer` and `dataloader` already exist.

# Cast the frozen base-model parameters to fp16 to save memory. Trainable
# adapter parameters stay in float32, which GradScaler requires.
cast_mixed_precision_params(peft_model, torch.float16)

# Training loop with automatic mixed precision
# (torch.amp.GradScaler("cuda") replaces the deprecated torch.cuda.amp API; PyTorch >= 2.3)
scaler = torch.amp.GradScaler("cuda")

for batch in dataloader:
    optimizer.zero_grad()

    # Autocast to the same dtype used for the frozen parameters above
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = peft_model(**batch)
        loss = outputs.loss

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

```

### Using Target Module Mappings

```python

from peft import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, LoraConfig

# Get recommended target modules for the model architecture; fall back to a
# manual choice for architectures missing from the mapping.
model_type = model.config.model_type
target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model_type) or [
    "q_proj",
    "v_proj",
]

# task_type assumes a causal LM; use e.g. "SEQ_2_SEQ_LM" for T5/BART models.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=target_modules,
    task_type="CAUSAL_LM",
)

```

### Handling Sequence-to-Sequence Tasks

```python

from peft import shift_tokens_right

# Assumes `tokenizer` and a seq2seq `peft_model` (e.g. T5 or BART with an adapter) exist.
base_config = peft_model.get_base_model().config

# Example usage in training
input_ids = tokenizer("Source text", return_tensors="pt").input_ids
# For padded batches, set label pad positions to -100 so the loss ignores them;
# shift_tokens_right maps -100 back to the pad token in the decoder inputs.
labels = tokenizer("Target text", return_tensors="pt").input_ids

# Prepare decoder input ids for seq2seq training. Use the model's own decoder
# start token: it differs between architectures (T5 starts with the pad token,
# BART with EOS), so hard-coding eos_token_id is wrong for T5.
decoder_input_ids = shift_tokens_right(
    labels,
    tokenizer.pad_token_id,
    base_config.decoder_start_token_id,
)

outputs = peft_model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    labels=labels,
)
loss = outputs.loss

```

### Loading Weights from Hub or Local Path

```python

from peft import load_peft_weights

# From a Hugging Face Hub repository
hub_weights = load_peft_weights("username/my-peft-adapter")

# From a local directory holding a saved adapter
local_weights = load_peft_weights("./local/peft/adapter")

# Onto an explicitly chosen device
gpu_weights = load_peft_weights("username/my-peft-adapter", device="cuda:0")

```