# Optimization

Optimizers and learning rate schedulers for training and fine-tuning transformer models. They cover the practices most transformer training recipes rely on: decoupled weight decay, learning rate warmup, and learning rate decay.

## Capabilities

### AdamW Optimizer

Adam optimizer with the weight decay fix. Standard Adam with L2 regularization adds the decay term to the gradient, where it is rescaled by Adam's adaptive step size. AdamW instead applies weight decay directly to the parameters.

```python { .api }
class AdamW:
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0.0,
        correct_bias=True
    ):
        """
        Initialize AdamW optimizer.

        Parameters:
        - params: Iterable of parameters to optimize
        - lr (float): Learning rate
        - betas (Tuple[float, float]): Coefficients for gradient and squared gradient moving averages
        - eps (float): Term added to denominator for numerical stability
        - weight_decay (float): Weight decay coefficient
        - correct_bias (bool): Whether to correct bias in moment estimates
        """

    def step(self, closure=None):
        """
        Perform a single optimization step.

        Parameters:
        - closure (callable, optional): Closure that reevaluates model and returns loss

        Returns:
        float: Loss value if closure is provided
        """

    def zero_grad(self):
        """
        Clear gradients of all optimized parameters.
        """
```

**Usage Example:**

```python
from pytorch_transformers import AdamW, BertForSequenceClassification
import torch

# Load model
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Initialize optimizer
optimizer = AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    correct_bias=False
)

# Training step
inputs = torch.randint(0, 1000, (8, 128))  # Dummy input
labels = torch.randint(0, 2, (8,))  # Dummy labels

optimizer.zero_grad()
outputs = model(inputs, labels=labels)
# pytorch_transformers models return tuples; the loss is the first
# element when labels are passed.
loss = outputs[0]
loss.backward()
optimizer.step()

print(f"Loss: {loss.item():.4f}")
```
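To make the difference between decoupled weight decay and L2 regularization concrete, the following is a minimal scalar sketch of a single Adam update. The helper `adam_step` and its `decoupled` flag are illustrative names, not part of `pytorch_transformers`, and the arithmetic follows the standard Adam formulation rather than the library's exact source.

```python
import math

def adam_step(p, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
              weight_decay=0.0, correct_bias=True, decoupled=True):
    """One Adam update on a scalar parameter (illustrative, not library code)."""
    beta1, beta2 = betas
    if not decoupled:
        # Classic L2 regularization: the decay term is folded into the
        # gradient, so Adam's adaptive denominator rescales it later.
        grad = grad + weight_decay * p
    # Exponential moving averages of the gradient and squared gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    step_size = lr
    if correct_bias:
        step_size *= math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)
    new_p = p - step_size * m / (math.sqrt(v) + eps)
    if decoupled:
        # Decoupled decay: shrink the weight directly, scaled only by lr.
        new_p -= lr * weight_decay * new_p
    return new_p, m, v

# With a zero gradient, only the decay term can move the weight.
p_decoupled, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.1)
p_l2, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.1,
                       decoupled=False)
print(p_decoupled, p_l2)
```

In this toy case the decoupled variant shrinks the weight by `lr * weight_decay` (to about 0.99), while the L2 variant moves it by roughly `lr` (to about 0.90) because Adam normalizes the decay term like any other gradient.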

### Learning Rate Schedulers

Various learning rate scheduling strategies commonly used in transformer training, including warmup phases and different decay patterns.

#### ConstantLRSchedule

Maintains a constant learning rate throughout training.

```python { .api }
def ConstantLRSchedule(optimizer, last_epoch=-1):
    """
    Create a constant learning rate schedule.

    Parameters:
    - optimizer: Wrapped optimizer
    - last_epoch (int): Index of last epoch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```

#### WarmupConstantSchedule

Linear warmup followed by constant learning rate.

```python { .api }
def WarmupConstantSchedule(optimizer, warmup_steps, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by constant learning rate.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - last_epoch (int): Index of last epoch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
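A `LambdaLR` scheduler scales the optimizer's base learning rate by a multiplier computed from the step count. The sketch below shows what such a multiplier could look like for warmup followed by a constant rate. The function name `warmup_constant_multiplier` is hypothetical, and this is an approximation of the behavior described above rather than the library's source.

```python
def warmup_constant_multiplier(step, warmup_steps):
    """Learning rate multiplier at `step` (hypothetical helper)."""
    if step < warmup_steps:
        # Ramp linearly from 0 to 1 over the warmup phase.
        return step / max(1, warmup_steps)
    # After warmup, keep the full base learning rate.
    return 1.0

base_lr = 2e-5
lrs = [base_lr * warmup_constant_multiplier(s, warmup_steps=100)
       for s in (0, 50, 100, 500)]
print(lrs)
```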

#### WarmupLinearSchedule

Linear warmup followed by linear decay to zero.

```python { .api }
def WarmupLinearSchedule(optimizer, warmup_steps, t_total, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by linear decay.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - t_total (int): Total number of training steps
    - last_epoch (int): Index of last epoch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
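As a rough illustration of the shape this schedule produces, the following sketch computes a multiplier that rises linearly during warmup and then falls linearly to zero at `t_total`. The helper name `warmup_linear_multiplier` is an assumption made for this example, not a library function.

```python
def warmup_linear_multiplier(step, warmup_steps, t_total):
    """Multiplier for linear warmup then linear decay (illustrative helper)."""
    if step < warmup_steps:
        # Warmup: 0 at step 0, approaching 1 at warmup_steps.
        return step / max(1, warmup_steps)
    # Decay: 1 at the end of warmup, 0 at t_total, clamped at 0 afterwards.
    return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))

for s in (0, 50, 100, 550, 1000):
    print(s, warmup_linear_multiplier(s, warmup_steps=100, t_total=1000))
```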

#### WarmupCosineSchedule

Linear warmup followed by cosine annealing decay.

```python { .api }
def WarmupCosineSchedule(optimizer, warmup_steps, t_total, cycles=0.5, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by cosine annealing.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - t_total (int): Total number of training steps
    - cycles (float): Number of cosine cycles (0.5 for half cosine)
    - last_epoch (int): Index of last epoch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
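The role of `cycles` is easier to see in code. The sketch below assumes the multiplier follows `0.5 * (1 + cos(2π · cycles · progress))` after warmup, which gives a half cosine from 1 to 0 when `cycles=0.5`. The helper name is hypothetical and the formula is a plausible reading of the description above, not a copy of the library's implementation.

```python
import math

def warmup_cosine_multiplier(step, warmup_steps, t_total, cycles=0.5):
    """Multiplier for linear warmup then cosine decay (illustrative helper)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, from 0 to 1.
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    # cycles=0.5 sweeps half a cosine period: from 1 down to 0.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))

for s in (100, 325, 550, 775, 1000):
    print(s, round(warmup_cosine_multiplier(s, 100, 1000), 4))
```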

#### WarmupCosineWithHardRestartsSchedule

Linear warmup followed by cosine annealing with hard restarts.

```python { .api }
def WarmupCosineWithHardRestartsSchedule(optimizer, warmup_steps, t_total, cycles=1.0, last_epoch=-1):
    """
    Create a schedule with linear warmup followed by cosine annealing with hard restarts.

    Parameters:
    - optimizer: Wrapped optimizer
    - warmup_steps (int): Number of warmup steps
    - t_total (int): Total number of training steps
    - cycles (float): Number of restart cycles
    - last_epoch (int): Index of last epoch

    Returns:
    LambdaLR: Learning rate scheduler
    """
```
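A "hard restart" means the multiplier jumps back to 1 at the start of each cycle instead of rising smoothly. The sketch below shows one way to get that behavior by wrapping the cosine phase with a modulo. The helper name is hypothetical and the formula is an assumption consistent with the description above.

```python
import math

def warmup_cosine_restarts_multiplier(step, warmup_steps, t_total, cycles=1.0):
    """Multiplier for warmup then cosine decay with hard restarts (illustrative)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    if progress >= 1.0:
        return 0.0
    # The modulo wraps the phase back to 0 at each restart, so the
    # multiplier jumps to 1 and decays toward 0 again.
    phase = (cycles * progress) % 1.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * phase)))

# Two cycles over 900 post-warmup steps: the restart happens at step 550.
for s in (100, 325, 549, 550, 1000):
    print(s, warmup_cosine_restarts_multiplier(s, 100, 1000, cycles=2.0))
```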

**Usage Examples:**

```python
from pytorch_transformers import (
    AdamW,
    BertForSequenceClassification,
    WarmupLinearSchedule,
    WarmupCosineSchedule,
    WarmupConstantSchedule
)

# Setup model and optimizer
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
optimizer = AdamW(model.parameters(), lr=2e-5)

# Training configuration
num_epochs = 3
num_training_steps = 1000
warmup_steps = 100

# The three schedulers below are alternatives shown side by side.
# In a real run, create only the one you intend to step.

# Linear schedule with warmup
linear_scheduler = WarmupLinearSchedule(
    optimizer,
    warmup_steps=warmup_steps,
    t_total=num_training_steps
)

# Cosine schedule with warmup
cosine_scheduler = WarmupCosineSchedule(
    optimizer,
    warmup_steps=warmup_steps,
    t_total=num_training_steps,
    cycles=0.5
)

# Constant schedule with warmup
constant_scheduler = WarmupConstantSchedule(
    optimizer,
    warmup_steps=warmup_steps
)

# Training loop example
for epoch in range(num_epochs):
    for step in range(num_training_steps // num_epochs):
        # Training step
        optimizer.zero_grad()
        # ... forward pass, loss calculation, backward pass ...
        optimizer.step()
        linear_scheduler.step()  # Update learning rate

        # Log current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        if step % 100 == 0:
            print(f"Epoch {epoch}, Step {step}, LR: {current_lr:.2e}")
```

## Optimization Best Practices

### Learning Rate Selection

**Fine-tuning Pre-trained Models:**
- BERT/RoBERTa: 2e-5, 3e-5, 5e-5
- GPT-2: 1e-4, 2e-4, 5e-4
- Smaller models: Higher learning rates (up to 1e-3)

**Warmup Steps:**
- Typically 10% of total training steps
- For short training: 500-1000 steps
- For long training: 5000-10000 steps

```python
# Recommended setup for BERT fine-tuning
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

optimizer = AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    correct_bias=False
)

scheduler = WarmupLinearSchedule(
    optimizer,
    warmup_steps=warmup_steps,
    t_total=total_steps
)
```

### Weight Decay Configuration

**Recommended weight decay values:**
- Typical starting point: 0.01
- Larger models: 0.1
- Smaller models: 0.001

**Parameter groups with different weight decay:**

```python
# Apply weight decay only to weights, not biases or layer norms
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
```
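The grouping above relies on plain substring matching against parameter names. The following sketch applies the same filter to a handful of made-up names to show which group each would land in. The names are hypothetical and only shaped like those a BERT-style model might expose.

```python
no_decay = ["bias", "LayerNorm.weight"]

# Hypothetical parameter names, for illustration only.
names = [
    "encoder.layer.0.attention.query.weight",
    "encoder.layer.0.attention.query.bias",
    "encoder.layer.0.output.LayerNorm.weight",
    "encoder.layer.0.output.LayerNorm.bias",
    "classifier.weight",
]

# Same predicate as the parameter groups: any substring match exempts the name.
decayed = [n for n in names if not any(nd in n for nd in no_decay)]
exempt = [n for n in names if any(nd in n for nd in no_decay)]

print("weight_decay=0.01:", decayed)
print("weight_decay=0.0: ", exempt)
```

Note that `LayerNorm.bias` is exempt through the `"bias"` entry, so only the layer norm weight needs its own pattern.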

### Gradient Clipping

```python
import torch.nn.utils as nn_utils

# Training step with gradient clipping
optimizer.zero_grad()
loss.backward()

# Clip gradients to prevent exploding gradients
nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
scheduler.step()
```
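Norm-based clipping rescales all gradients by a common factor when their combined L2 norm exceeds `max_norm`, which preserves the update direction. The sketch below reproduces that idea on a flat list of floats. The function `clip_by_global_norm` is a hypothetical stand-in written for this example and operates on plain numbers, not tensors.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale `grads` so their combined L2 norm is at most `max_norm`."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        # Norm is too large: shrink every gradient by the same factor.
        grads = [g * scale for g in grads]
    return grads, total_norm

clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)

# Gradients already within the limit are returned unchanged.
small, _ = clip_by_global_norm([0.3, 0.4], max_norm=1.0)
print(small)
```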

### Mixed Precision Training

```python
from torch.cuda.amp import autocast, GradScaler

# Initialize gradient scaler for mixed precision
scaler = GradScaler()

# Training step with mixed precision
optimizer.zero_grad()

with autocast():
    outputs = model(**inputs)
    # pytorch_transformers models return tuples; the loss is the first
    # element when labels are included in the inputs.
    loss = outputs[0]

# Scale loss and backward pass
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
scheduler.step()
```

## Schedule Visualization

Different learning rate schedules behave differently during training:

**Linear Schedule**: Steady decrease after warmup
- Best for: Most fine-tuning tasks
- Characteristics: Predictable, stable convergence

**Cosine Schedule**: Smooth decay following a cosine curve
- Best for: Long training runs, better final performance
- Characteristics: Slow decay at the start and end, fastest decay mid-training

**Constant Schedule**: Maintains rate after warmup
- Best for: Continued pre-training, domain adaptation
- Characteristics: No decay, constant exploration

**Cosine with Restarts**: Learning rate periodically resets to its peak and decays again
- Best for: Finding better local minima, avoiding plateaus
- Characteristics: Multiple convergence opportunities
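
For a quick look at a schedule's shape without a plotting library, a multiplier function can be rendered as a line of text. The sketch below is one way to do that. Both `sparkline` and `linear_with_warmup` are hypothetical helpers written for this example, and the linear formula is an assumed form of warmup followed by linear decay.

```python
def sparkline(multiplier, total_steps, width=20):
    """Render a schedule multiplier (values in [0, 1]) as a row of characters."""
    levels = " .:-=+*#%@"  # ordered from low to high
    chars = []
    for i in range(width):
        # Sample the schedule at evenly spaced steps, including both ends.
        step = round(i * total_steps / (width - 1))
        value = min(1.0, max(0.0, multiplier(step)))
        chars.append(levels[round(value * (len(levels) - 1))])
    return "".join(chars)

def linear_with_warmup(step, warmup_steps=100, t_total=1000):
    """Assumed multiplier: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))

line = sparkline(linear_with_warmup, total_steps=1000)
print(f"[{line}]")
```

Any function that maps a step to a value between 0 and 1 can be passed as `multiplier`, so the same helper works for comparing the other schedules.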