# Advanced Optimizers

Specialized and experimental optimization algorithms including second-order methods, adaptive variants, and research optimizers. These optimizers implement cutting-edge techniques and may require more careful tuning than core optimizers.

## Capabilities

### Lion Optimizer

Lion (Evolved Sign Momentum) optimizer that uses sign-based updates for memory efficiency and competitive performance.

```python { .api }
def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=0.0):
    """
    Lion optimizer (Evolved Sign Momentum).

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for momentum (default: 0.9)
        b2: Exponential decay rate for moving average (default: 0.99)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation
    """
```
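
As a worked illustration (not the library implementation), a single-scalar sketch of Lion's update rule, assuming the decoupled weight decay form from the Lion paper; the name `lion_step` is hypothetical:

```python
def lion_step(param, grad, m, lr=1e-4, b1=0.9, b2=0.99, weight_decay=0.0):
    """One Lion update for a single scalar parameter.

    The update direction is the sign of an interpolation between the
    momentum buffer and the current gradient. Only one buffer is kept
    per parameter, which is where the memory savings over Adam come from.
    """
    # Interpolate momentum and gradient, then take the sign.
    update = b1 * m + (1 - b1) * grad
    direction = (update > 0) - (update < 0)  # sign(update)
    # Decoupled weight decay, as in AdamW.
    new_param = param - lr * (direction + weight_decay * param)
    # The momentum buffer decays with a separate rate b2.
    new_m = b2 * m + (1 - b2) * grad
    return new_param, new_m
```

Because the step magnitude is always exactly the learning rate (plus decay), Lion typically wants a 3–10x smaller learning rate than Adam.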

### LARS Optimizer

Layer-wise Adaptive Rate Scaling (LARS) optimizer for large batch training.

```python { .api }
def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.):
    """
    LARS (Layer-wise Adaptive Rate Scaling) optimizer.

    Args:
        learning_rate: Learning rate or schedule
        weight_decay: Weight decay coefficient (default: 0.0)
        trust_coefficient: Trust coefficient for layer-wise adaptation (default: 0.001)
        eps: Small constant for numerical stability (default: 0.0)

    Returns:
        GradientTransformation
    """
```
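
The core of LARS is the per-layer trust ratio that rescales the learning rate by the ratio of weight norm to gradient norm. A minimal sketch, assuming the common formulation from the LARS paper (the function name and the flat-list layer representation are illustrative):

```python
import math

def lars_trust_ratio(params, grads, weight_decay=0.0,
                     trust_coefficient=0.001, eps=0.0):
    """Layer-wise trust ratio: tc * ||w|| / (||g|| + wd * ||w|| + eps).

    Layers with small gradients relative to their weights get larger
    effective steps, which stabilizes very large batch training.
    """
    w_norm = math.sqrt(sum(w * w for w in params))
    g_norm = math.sqrt(sum(g * g for g in grads))
    if w_norm == 0.0 or g_norm == 0.0:
        return 1.0  # fall back to an unscaled update
    return trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)
```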

### LAMB Optimizer

Layer-wise Adaptive Moments (LAMB) optimizer, designed for large-batch training.

```python { .api }
def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None):
    """
    LAMB (Layer-wise Adaptive Moments for Batch training) optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-6)
        weight_decay: Weight decay coefficient (default: 0.0)
        mask: Optional mask for parameter selection

    Returns:
        GradientTransformation
    """
```

### L-BFGS Optimizer

Limited-memory Broyden-Fletcher-Goldfarb-Shanno quasi-Newton method.

```python { .api }
def lbfgs(learning_rate, memory_size=10, scale_init_preconditioner=True):
    """
    L-BFGS quasi-Newton optimizer.

    Args:
        learning_rate: Learning rate or schedule
        memory_size: Number of previous gradients to store (default: 10)
        scale_init_preconditioner: Whether to scale initial preconditioner (default: True)

    Returns:
        GradientTransformation
    """
```

### Yogi Optimizer

Yogi optimizer, which controls the increase in the effective learning rate to achieve more reliable convergence than Adam.

```python { .api }
def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3, initial_accumulator=1e-6):
    """
    Yogi optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-3)
        initial_accumulator: Initial value for accumulator (default: 1e-6)

    Returns:
        GradientTransformation
    """
```
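
Yogi differs from Adam only in the second-moment update. A scalar sketch of both rules for comparison (function names are illustrative, following the update equations in the Yogi paper):

```python
def yogi_second_moment(v, grad, b2=0.999):
    """Yogi's additive second-moment update:
        v <- v - (1 - b2) * sign(v - g^2) * g^2
    The change in v is bounded by (1 - b2) * g^2 regardless of how far v is
    from g^2, which controls the growth of the effective learning rate.
    """
    g2 = grad * grad
    sign = (v > g2) - (v < g2)
    return v - (1 - b2) * sign * g2

def adam_second_moment(v, grad, b2=0.999):
    """Adam's multiplicative EMA update, for comparison."""
    return b2 * v + (1 - b2) * grad * grad
```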

### NovoGrad Optimizer

NovoGrad optimizer that combines adaptive learning rates with gradient normalization.

```python { .api }
def novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-6, weight_decay=0.):
    """
    NovoGrad optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.25)
        eps: Small constant for numerical stability (default: 1e-6)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation
    """
```

### RAdam Optimizer

Rectified Adam optimizer that addresses the variance issue in early training stages.

```python { .api }
def radam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, threshold=5.0):
    """
    RAdam (Rectified Adam) optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        threshold: Threshold for variance tractability (default: 5.0)

    Returns:
        GradientTransformation
    """
```
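
The "rectification" is a step-dependent multiplier on the adaptive learning rate; when the variance of the second moment is not yet tractable, RAdam falls back to plain momentum. A sketch of the rectification term from the RAdam paper (the function name is illustrative):

```python
import math

def radam_rectification(t, b2=0.999, threshold=5.0):
    """Rectification factor at step t (1-indexed), or None when the
    approximated degrees of freedom rho_t fall below the threshold,
    signaling a fallback to un-rectified momentum updates."""
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    rho_t = rho_inf - 2.0 * t * b2**t / (1.0 - b2**t)
    if rho_t <= threshold:
        return None  # variance not tractable in early steps
    return math.sqrt(((rho_t - 4) * (rho_t - 2) * rho_inf)
                     / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
```

For large t the factor approaches 1 and RAdam behaves like Adam; in the first few steps it is None, which is exactly the warmup-like behavior the method provides automatically.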

### SM3 Optimizer

SM3 optimizer designed for sparse gradients with memory-efficient second moments.

```python { .api }
def sm3(learning_rate, momentum=0.9):
    """
    SM3 optimizer for sparse gradients.

    Args:
        learning_rate: Learning rate or schedule
        momentum: Momentum coefficient (default: 0.9)

    Returns:
        GradientTransformation
    """
```
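
The memory saving comes from covering each parameter entry with shared accumulators. A loose sketch of the cover-based bookkeeping for a 2-D parameter, assuming row/column covers as in the SM3 paper (names and list-of-lists layout are illustrative, not the optax internals):

```python
def sm3_accumulators(row_acc, col_acc, grad_matrix):
    """One SM3 bookkeeping step for an n x m parameter.

    Instead of a full n x m matrix of second-moment accumulators,
    keep one accumulator per row and per column (n + m values). The
    effective per-entry accumulator is min(row_acc[i], col_acc[j]) + g^2,
    an upper bound on the running sum of squared gradients per entry.
    """
    n_rows, n_cols = len(grad_matrix), len(grad_matrix[0])
    v = [[min(row_acc[i], col_acc[j]) + grad_matrix[i][j] ** 2
          for j in range(n_cols)] for i in range(n_rows)]
    # Fold the per-entry estimates back into the shared accumulators.
    new_row = [max(v[i][j] for j in range(n_cols)) for i in range(n_rows)]
    new_col = [max(v[i][j] for i in range(n_rows)) for j in range(n_cols)]
    return v, new_row, new_col
```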

### Fromage Optimizer

Frobenius matched gradient descent optimizer.

```python { .api }
def fromage(learning_rate):
    """
    Fromage (Frobenius matched gradient descent) optimizer.

    Args:
        learning_rate: Learning rate or schedule

    Returns:
        GradientTransformation
    """
```

### Specialized SGD Variants

#### Noisy SGD

SGD with gradient noise injection for improved generalization.

```python { .api }
def noisy_sgd(learning_rate, eta=0.01):
    """
    Noisy SGD with gradient noise injection.

    Args:
        learning_rate: Learning rate or schedule
        eta: Noise scaling parameter (default: 0.01)

    Returns:
        GradientTransformation
    """
```

#### Sign SGD

SGD using only the sign of gradients.

```python { .api }
def sign_sgd(learning_rate):
    """
    Sign SGD optimizer using gradient signs only.

    Args:
        learning_rate: Learning rate or schedule

    Returns:
        GradientTransformation
    """
```

#### Polyak SGD

SGD with Polyak momentum.

```python { .api }
def polyak_sgd(learning_rate, polyak_momentum=0.9):
    """
    SGD with Polyak momentum.

    Args:
        learning_rate: Learning rate or schedule
        polyak_momentum: Polyak momentum coefficient (default: 0.9)

    Returns:
        GradientTransformation
    """
```

### RProp Optimizer

Resilient backpropagation optimizer that uses only gradient signs.

```python { .api }
def rprop(learning_rate, eta_minus=0.5, eta_plus=1.2, min_step_size=1e-6, max_step_size=50.):
    """
    RProp (Resilient backpropagation) optimizer.

    Args:
        learning_rate: Initial step size
        eta_minus: Factor for decreasing step size (default: 0.5)
        eta_plus: Factor for increasing step size (default: 1.2)
        min_step_size: Minimum step size (default: 1e-6)
        max_step_size: Maximum step size (default: 50.0)

    Returns:
        GradientTransformation
    """
```
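
RProp's per-parameter step size depends only on the sign agreement of consecutive gradients, not their magnitudes. A minimal sketch of that adaptation rule (the function name is illustrative):

```python
def rprop_step_size(step_size, grad, prev_grad,
                    eta_minus=0.5, eta_plus=1.2,
                    min_step_size=1e-6, max_step_size=50.0):
    """Adapt one parameter's step size from the sign of grad * prev_grad:
    grow it when consecutive gradients agree, shrink it on a sign flip
    (which indicates the last step overshot a local minimum)."""
    if grad * prev_grad > 0:        # same direction: accelerate
        step_size = min(step_size * eta_plus, max_step_size)
    elif grad * prev_grad < 0:      # sign flip: back off
        step_size = max(step_size * eta_minus, min_step_size)
    return step_size                # zero product: leave unchanged
```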

### Optimistic Methods

#### Optimistic Gradient Descent

Optimistic gradient descent for saddle point problems.

```python { .api }
def optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0):
    """
    Optimistic gradient descent.

    Args:
        learning_rate: Learning rate or schedule
        alpha: Extrapolation coefficient (default: 1.0)
        beta: Update coefficient (default: 1.0)

    Returns:
        GradientTransformation
    """
```
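
To see why the optimistic correction matters, consider the bilinear saddle-point problem min over x, max over y of f(x, y) = x*y, where plain simultaneous gradient descent spirals outward. A sketch using the classic 2*g_t - g_{t-1} rule, which corresponds to the default coefficients (the function name and setup are illustrative):

```python
def bilinear_game_residual(steps=200, lr=0.1, optimistic=True):
    """Run simultaneous updates on f(x, y) = x * y from (1, 1) and return
    |x| + |y|; the saddle point is (0, 0). The optimistic variant replaces
    each gradient g_t with the extrapolation 2 * g_t - g_{t-1}."""
    x, y = 1.0, 1.0
    pgx, pgy = 0.0, 0.0           # previous gradients
    for _ in range(steps):
        gx, gy = y, -x            # descent in x, ascent in y
        if optimistic:
            ux, uy = 2 * gx - pgx, 2 * gy - pgy
        else:
            ux, uy = gx, gy
        x, y = x - lr * ux, y - lr * uy
        pgx, pgy = gx, gy
    return abs(x) + abs(y)
```

The plain iterates drift away from the saddle point while the optimistic ones contract toward it, which is the motivating behavior for this family of methods.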

#### Optimistic Adam

Optimistic variant of Adam optimizer.

```python { .api }
def optimistic_adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Optimistic Adam optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```

### Lookahead Wrapper

Lookahead optimizer that can wrap any base optimizer.

```python { .api }
def lookahead(fast_optimizer, lookahead_steps=5, lookahead_alpha=0.5):
    """
    Lookahead optimizer wrapper.

    Args:
        fast_optimizer: Base optimizer to wrap
        lookahead_steps: Number of fast optimizer steps before lookahead (default: 5)
        lookahead_alpha: Interpolation factor for lookahead (default: 0.5)

    Returns:
        GradientTransformation
    """
```
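
The wrapper maintains two copies of the parameters: fast weights updated every step by the base optimizer, and slow weights synchronized every `lookahead_steps` steps. A sketch of the synchronization rule (the function name and flat-list parameters are illustrative):

```python
def lookahead_sync(slow, fast, alpha=0.5):
    """Lookahead synchronization: the slow weights move a fraction alpha
    toward the current fast weights, then the fast weights restart from
    the new slow weights. Returns (new_slow, new_fast)."""
    new_slow = [s + alpha * (f - s) for s, f in zip(slow, fast)]
    return new_slow, list(new_slow)
```

Averaging toward the slow weights discards part of the fast optimizer's variance, which is why Lookahead often tolerates larger inner learning rates.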

## Usage Example

```python
import optax
import jax.numpy as jnp

# Initialize parameters
params = {'weights': jnp.ones((100, 50)), 'bias': jnp.zeros((50,))}

# Advanced optimizers for different scenarios
lion_opt = optax.lion(learning_rate=0.0001)   # Memory efficient
lars_opt = optax.lars(learning_rate=0.01)     # Large batch training
lamb_opt = optax.lamb(learning_rate=0.001)    # Large batch training
lbfgs_opt = optax.lbfgs(learning_rate=1.0)    # Second-order method

# Lookahead wrapper around a base optimizer
base_opt = optax.adam(learning_rate=0.001)
lookahead_opt = optax.lookahead(base_opt, lookahead_steps=5)

# Initialize optimizer states
lion_state = lion_opt.init(params)
lookahead_state = lookahead_opt.init(params)

# Usage in a training loop
def training_step(params, opt_state, gradients, optimizer):
    updates, new_opt_state = optimizer.update(gradients, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state
```