
# Learning Rate Schedules

Flexible scheduling functions for learning rates and other hyperparameters, including warmup, decay, and cyclic schedules. These schedules shape training dynamics and help achieve better convergence.

## Capabilities

### Basic Schedules

#### Constant Schedule

```python { .api }
def constant_schedule(value):
    """
    Constant value schedule.

    Args:
        value: Constant value to return

    Returns:
        Schedule function
    """
```

#### Linear Schedule

```python { .api }
def linear_schedule(init_value, end_value, transition_steps):
    """
    Linear interpolation between two values.

    Args:
        init_value: Initial value
        end_value: Final value
        transition_steps: Number of steps for transition

    Returns:
        Schedule function
    """
```

#### Polynomial Schedule

```python { .api }
def polynomial_schedule(init_value, end_value, power, transition_steps):
    """
    Polynomial decay schedule.

    Args:
        init_value: Initial value
        end_value: Final value
        power: Polynomial power (1.0 = linear, 2.0 = quadratic, etc.)
        transition_steps: Number of steps for transition

    Returns:
        Schedule function
    """
```
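
The `power` argument controls how quickly the value falls early in the transition. A minimal sketch, assuming the signature above, comparing a linear and a quadratic decay at the midpoint:

```python
import optax

# Both schedules decay from 0.001 to 0.0 over 1000 steps.
linear = optax.polynomial_schedule(init_value=1e-3, end_value=0.0, power=1.0, transition_steps=1000)
quadratic = optax.polynomial_schedule(init_value=1e-3, end_value=0.0, power=2.0, transition_steps=1000)

print(linear(500))     # 0.0005  -- halfway through the transition
print(quadratic(500))  # 0.00025 -- the quadratic curve has already dropped further
```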

### Exponential Decay

```python { .api }
def exponential_decay(init_value, decay_rate, transition_steps, transition_begin=0, staircase=False, end_value=None):
    """
    Exponential decay schedule.

    Args:
        init_value: Initial value
        decay_rate: Decay rate (e.g., 0.96 for 4% decay)
        transition_steps: Steps between decay applications
        transition_begin: Step to begin decay (default: 0)
        staircase: Whether to apply decay in discrete steps (default: False)
        end_value: Minimum value to decay to (default: None)

    Returns:
        Schedule function
    """
```
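
The `staircase` flag switches between continuous decay and decay applied only at whole `transition_steps` intervals. A small sketch, assuming the signature above (values are approximate):

```python
import optax

smooth = optax.exponential_decay(init_value=1e-3, decay_rate=0.96, transition_steps=100)
stepped = optax.exponential_decay(init_value=1e-3, decay_rate=0.96, transition_steps=100, staircase=True)

print(smooth(50))   # ~0.00098 -- 0.001 * 0.96 ** 0.5, decay applied continuously
print(stepped(50))  # 0.001    -- unchanged until a full 100-step interval has passed
print(smooth(200))  # ~0.00092 -- 0.001 * 0.96 ** 2
```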

### Cosine Schedules

#### Cosine Decay

```python { .api }
def cosine_decay_schedule(init_value, decay_steps, alpha=0.0):
    """
    Cosine decay schedule.

    Args:
        init_value: Initial value
        decay_steps: Number of steps for full cosine cycle
        alpha: Minimum value as fraction of init_value (default: 0.0)

    Returns:
        Schedule function
    """
```
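
`alpha` puts a floor under the decayed value, expressed as a fraction of `init_value`. A minimal sketch, assuming the signature above:

```python
import optax

# Decay from 0.001 towards a floor of 0.001 * 0.1 = 0.0001 over 1000 steps.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000, alpha=0.1)

print(schedule(0))     # 0.001
print(schedule(1000))  # 0.0001 -- never drops below alpha * init_value
```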

#### Cosine One-Cycle

```python { .api }
def cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, final_div_factor=1e4):
    """
    One-cycle cosine schedule (warmup, decay, final decay).

    Args:
        transition_steps: Total number of steps
        peak_value: Maximum value at peak
        pct_start: Percentage of steps for warmup phase (default: 0.3)
        pct_final: Percentage of steps before final decay (default: 0.85)
        final_div_factor: Final value divisor (default: 1e4)

    Returns:
        Schedule function
    """
```
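
A short sketch with the default warmup fraction, assuming the signature above; the exact final value depends on `final_div_factor`, so the comments stay qualitative:

```python
import optax

schedule = optax.cosine_onecycle_schedule(transition_steps=10_000, peak_value=1e-2)

print(schedule(3_000))   # ~0.01 -- the peak, reached after the default 30% warmup
print(schedule(10_000))  # far below the peak after the final decay
```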

### Piecewise Schedules

#### Piecewise Constant

```python { .api }
def piecewise_constant_schedule(boundaries_and_scales):
    """
    Piecewise constant schedule with different values in different intervals.

    Args:
        boundaries_and_scales: Dict mapping step boundaries to scale factors

    Returns:
        Schedule function
    """
```

#### Piecewise Interpolate

```python { .api }
def piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales):
    """
    Piecewise schedule with interpolation between boundaries.

    Args:
        interpolate_type: Type of interpolation ('linear', 'cosine')
        init_value: Initial value
        boundaries_and_scales: Dict mapping boundaries to scale factors

    Returns:
        Schedule function
    """
```
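
A small sketch contrasting the two interpolation types, assuming the signature above, with a single boundary that scales the value by 0.1 at step 1000:

```python
import optax

boundaries_and_scales = {1000: 0.1}  # scale the value by 0.1 at step 1000

linear_interp = optax.piecewise_interpolate_schedule('linear', 1e-3, boundaries_and_scales)
cosine_interp = optax.piecewise_interpolate_schedule('cosine', 1e-3, boundaries_and_scales)

print(linear_interp(500))   # 0.00055 -- halfway between 0.001 and 0.0001
print(cosine_interp(500))   # 0.00055 at the midpoint too, but the curve flattens near the boundaries
print(linear_interp(1000))  # 0.0001 from this step onwards
```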

### Warmup Schedules

#### Warmup + Constant

```python { .api }
def warmup_constant_schedule(init_value, peak_value, warmup_steps):
    """
    Linear warmup followed by constant value.

    Args:
        init_value: Initial value during warmup
        peak_value: Constant value after warmup
        warmup_steps: Number of warmup steps

    Returns:
        Schedule function
    """
```

#### Warmup + Cosine Decay

```python { .api }
def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0):
    """
    Linear warmup followed by cosine decay.

    Args:
        init_value: Initial value during warmup
        peak_value: Peak value after warmup
        warmup_steps: Number of warmup steps
        decay_steps: Number of decay steps after warmup
        end_value: Final value after decay (default: 0.0)

    Returns:
        Schedule function
    """
```

#### Warmup + Exponential Decay

```python { .api }
def warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None):
    """
    Linear warmup followed by exponential decay.

    Args:
        init_value: Initial value during warmup
        peak_value: Peak value after warmup
        warmup_steps: Number of warmup steps
        transition_steps: Steps between decay applications
        decay_rate: Exponential decay rate
        transition_begin: Step to begin decay (default: 0)
        staircase: Whether to apply decay in discrete steps (default: False)
        end_value: Minimum decay value (default: None)

    Returns:
        Schedule function
    """
```
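
A minimal sketch combining linear warmup with exponential decay, assuming the signature above; keyword arguments keep the call unambiguous:

```python
import optax

schedule = optax.warmup_exponential_decay_schedule(
    init_value=0.0,         # start the warmup at 0
    peak_value=1e-3,        # reach 0.001 after the warmup
    warmup_steps=500,
    transition_steps=1000,  # decay interval after the peak
    decay_rate=0.5,         # halve the value every 1000 steps
    end_value=1e-5,         # never decay below 1e-5
)

print(schedule(250))   # 0.0005 -- halfway through the linear warmup
print(schedule(500))   # 0.001  -- the peak
print(schedule(1500))  # 0.0005 -- one full decay interval past the peak
```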

### Advanced Schedules

#### Linear One-Cycle

```python { .api }
def linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, final_div_factor=1e4):
    """
    One-cycle linear schedule (warmup, decay, final decay).

    Args:
        transition_steps: Total number of steps
        peak_value: Maximum value at peak
        pct_start: Percentage of steps for warmup phase (default: 0.3)
        pct_final: Percentage of steps before final decay (default: 0.85)
        final_div_factor: Final value divisor (default: 1e4)

    Returns:
        Schedule function
    """
```
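
A short sketch of the one-cycle shape, assuming the signature above; the value at the very end depends on `final_div_factor`, so that comment stays qualitative:

```python
import optax

schedule = optax.linear_onecycle_schedule(
    transition_steps=10_000,
    peak_value=1e-2,
    pct_start=0.25,  # ramp up over the first 25% of steps
    pct_final=0.85,  # ramp back down until 85% of the steps have passed
)

print(schedule(2_500))   # ~0.01 -- the peak
print(schedule(10_000))  # a very small value after the final decay phase
```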

#### SGDR Schedule

```python { .api }
def sgdr_schedule(cosine_decay_schedule, restart_period, t_mult=1.0):
    """
    Stochastic Gradient Descent with Restarts (SGDR) schedule.

    Args:
        cosine_decay_schedule: Base cosine decay schedule
        restart_period: Initial restart period
        t_mult: Multiplier for restart period (default: 1.0)

    Returns:
        Schedule function
    """
```

### Schedule Composition

#### Join Schedules

```python { .api }
def join_schedules(schedules, boundaries):
    """
    Join multiple schedules at specified boundaries.

    Args:
        schedules: List of schedule functions
        boundaries: List of step boundaries for schedule transitions

    Returns:
        Combined schedule function
    """
```
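
A minimal sketch of the hand-off at a boundary; it assumes the later schedule is evaluated with an offset step count (`step - boundary`), which matches optax's behavior:

```python
import optax

warmup = optax.linear_schedule(0.0, 1e-3, 100)  # 0 -> 0.001 over 100 steps
decay = optax.cosine_decay_schedule(1e-3, 900)  # then cosine decay over 900 steps

schedule = optax.join_schedules([warmup, decay], boundaries=[100])

print(schedule(50))    # 0.0005 -- still inside the warmup schedule
print(schedule(100))   # 0.001  -- the decay schedule starts from its own step 0
print(schedule(1000))  # ~0.0   -- end of the cosine decay, 900 steps past the boundary
```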

### Hyperparameter Injection

#### Static Hyperparameters

```python { .api }
def inject_hyperparams(transformation, **scheduled_hyperparams):
    """
    Inject scheduled hyperparameters into transformation.

    Args:
        transformation: Base gradient transformation
        **scheduled_hyperparams: Named schedule functions for hyperparameters

    Returns:
        GradientTransformation with scheduled hyperparameters
    """
```

#### Stateful Hyperparameters

```python { .api }
def inject_stateful_hyperparams(transformation, **scheduled_hyperparams):
    """
    Inject stateful scheduled hyperparameters into transformation.

    Args:
        transformation: Base gradient transformation
        **scheduled_hyperparams: Named stateful schedule functions

    Returns:
        GradientTransformation with stateful scheduled hyperparameters
    """
```

### Schedule State Classes

```python { .api }
class InjectHyperparamsState:
    """State for hyperparameter injection."""
    count: int
    inner_state: OptState

class InjectStatefulHyperparamsState:
    """State for stateful hyperparameter injection."""
    count: int
    inner_state: OptState
    hyperparams_states: dict

class WrappedSchedule:
    """Wrapper for schedule functions with state."""
    schedule_fn: Schedule
```

## Usage Examples

### Basic Schedule Usage

```python
import optax

# Create different schedules
constant_lr = optax.constant_schedule(0.001)
linear_decay = optax.linear_schedule(0.001, 0.0001, 1000)
cosine_decay = optax.cosine_decay_schedule(0.001, 1000)
exponential_decay = optax.exponential_decay(0.001, 0.96, 100)

# Use a schedule with an optimizer
optimizer = optax.adam(learning_rate=cosine_decay)

# Evaluate schedules at different steps
step_0_lr = constant_lr(0)         # 0.001
step_500_lr = linear_decay(500)    # 0.00055
step_1000_lr = cosine_decay(1000)  # close to 0
```

### Warmup Schedules

```python
# Warmup followed by cosine decay
warmup_cosine = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=9000,
    end_value=0.00001
)

# Warmup followed by constant
warmup_constant = optax.warmup_constant_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=500
)

# Use with optimizer
optimizer = optax.adamw(learning_rate=warmup_cosine, weight_decay=0.01)
```

### Piecewise Schedules

```python
# Different learning rates at different training phases
boundaries_and_scales = {
    500: 1.0,   # LR = init_value * 1.0 until step 500
    1000: 0.5,  # LR = init_value * 0.5 from step 500-1000
    1500: 0.1   # LR = init_value * 0.1 from step 1000-1500
}

piecewise_sched = optax.piecewise_constant_schedule(boundaries_and_scales)

# With interpolation
piecewise_interp = optax.piecewise_interpolate_schedule(
    'linear', 0.001, boundaries_and_scales
)
```

### Advanced Scheduling

```python
# One-cycle schedule
onecycle = optax.cosine_onecycle_schedule(
    transition_steps=5000,
    peak_value=0.01,
    pct_start=0.3,   # 30% warmup
    pct_final=0.85   # 85% before final decay
)

# SGDR with restarts
base_cosine = optax.cosine_decay_schedule(0.001, 1000)
sgdr = optax.sgdr_schedule(base_cosine, restart_period=1000, t_mult=2.0)

# Join multiple schedules
schedules = [
    optax.constant_schedule(0.001),             # First 1000 steps
    optax.linear_schedule(0.001, 0.0001, 1000)  # Next 1000 steps
]
joined = optax.join_schedules(schedules, [1000])
```

### Hyperparameter Scheduling

```python
# Schedule multiple hyperparameters
base_transform = optax.scale_by_adam()

scheduled_transform = optax.inject_hyperparams(
    base_transform,
    learning_rate=optax.cosine_decay_schedule(0.001, 1000),
    b1=optax.linear_schedule(0.9, 0.95, 500),
    b2=optax.constant_schedule(0.999)
)

# Create complete optimizer
optimizer = optax.chain(
    scheduled_transform,
    optax.scale(-1.0)  # Negate updates for gradient descent
)
```

### Training Loop Integration

```python
import jax
import optax

# Create schedule
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=9000
)

optimizer = optax.adam(learning_rate=schedule)

def train_step(params, opt_state, batch, step):
    """Training step with scheduled learning rate."""

    def loss_fn(p):
        return compute_loss(p, batch)

    loss_val, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Current learning rate for logging
    current_lr = schedule(step)

    return params, opt_state, loss_val, current_lr
```