
# Experimental Optimizers (contrib)

The `optax.contrib` module contains experimental optimizers and techniques under active development. These are cutting-edge optimization methods that may not be as stable as the core optimizers but represent the latest research in optimization.

**Note**: Experimental features may have API changes in future versions.

## Capabilities

### Advanced Adaptive Optimizers

#### Sharpness-Aware Minimization (SAM)

```python { .api }
def sam(base_optimizer, rho=0.05, normalize=True):
    """Sharpness-Aware Minimization optimizer.

    Args:
        base_optimizer: Base optimizer to use (e.g., SGD, Adam)
        rho: Neighborhood size for sharpness computation (default: 0.05)
        normalize: Whether to normalize perturbation (default: True)

    Returns:
        GradientTransformation: SAM optimizer
    """
```
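To make the two-step update concrete, here is a minimal plain-Python sketch of SAM on the toy loss f(w) = w², with no JAX or optax involved. The names `toy_grad` and `sam_step` are illustrative, not part of the API:

```python
def toy_grad(w):
    """Gradient of the toy loss f(w) = w**2."""
    return 2.0 * w

def sam_step(w, lr=0.1, rho=0.05):
    """One SAM update: ascend to a nearby high-loss point, then descend."""
    g = toy_grad(w)
    # Step 1: normalized ascent within a rho-sized neighborhood
    # (corresponding to normalize=True).
    eps = rho * g / (abs(g) + 1e-12)
    # Step 2: base-optimizer descent using the gradient at the perturbed point.
    return w - lr * toy_grad(w + eps)

w = 1.0
for _ in range(50):
    w = sam_step(w)
# w ends up in a small neighborhood of the minimum at 0
```

Because the descent direction is computed at the perturbed point, SAM biases training toward flat minima; in optax it wraps an arbitrary base optimizer rather than plain gradient descent.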

#### Prodigy Optimizer

```python { .api }
def prodigy(learning_rate=1.0, eps=1e-8, beta1=0.9, beta2=0.999, weight_decay=0.0):
    """Prodigy adaptive learning rate optimizer.

    Args:
        learning_rate: Initial learning rate (default: 1.0)
        eps: Numerical stability parameter (default: 1e-8)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: Prodigy optimizer
    """
```

#### Sophia Optimizer

```python { .api }
def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4):
    """Sophia optimizer using second-order information.

    Args:
        learning_rate: Learning rate
        beta1: First moment decay rate (default: 0.965)
        beta2: Second moment decay rate (default: 0.99)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 1e-4)

    Returns:
        GradientTransformation: Sophia optimizer
    """
```
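As a rough illustration of the update rule (assumed here to be a clipped ratio of a gradient EMA to a Hessian-diagonal EMA; this is a sketch, not the optax implementation), on f(w) = w², whose Hessian is exactly 2:

```python
def sophia_clip(x, bound):
    """Element-wise clip to [-bound, bound]."""
    return max(-bound, min(bound, x))

def sophia_steps(w0, steps=300, lr=0.05, beta1=0.965, beta2=0.99, eps=1e-8):
    w, m, h = w0, 0.0, 0.0
    for _ in range(steps):
        g = 2.0 * w          # gradient of f(w) = w**2
        hess = 2.0           # exact Hessian diagonal of the quadratic
        m = beta1 * m + (1 - beta1) * g       # first-moment EMA
        h = beta2 * h + (1 - beta2) * hess    # Hessian-diagonal EMA
        # Precondition by the curvature estimate, then clip the step.
        w = w - lr * sophia_clip(m / max(h, eps), 1.0)
    return w

w_final = sophia_steps(1.0)   # approaches the minimum at w = 0
```

The element-wise clipping is what keeps the preconditioned step bounded when the curvature estimate is small or noisy.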

### Schedule-Free Optimizers

#### Schedule-Free AdamW

```python { .api }
def schedule_free_adamw(learning_rate=0.0025, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    """Schedule-free AdamW optimizer that doesn't require learning rate schedules.

    Args:
        learning_rate: Learning rate (default: 0.0025)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: Schedule-free AdamW optimizer
    """

def schedule_free_sgd(learning_rate=1.0, momentum=0.9, weight_decay=0.0):
    """Schedule-free SGD optimizer.

    Args:
        learning_rate: Learning rate (default: 1.0)
        momentum: Momentum coefficient (default: 0.9)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: Schedule-free SGD optimizer
    """

def schedule_free_eval_params(state, params):
    """Extract evaluation parameters from schedule-free optimizer state.

    Args:
        state: State from a schedule-free optimizer
        params: Current training parameters

    Returns:
        Parameters suitable for evaluation/inference
    """
```
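The schedule-free iteration can be sketched in plain Python, following the interpolation/averaging scheme of "The Road Less Scheduled": gradients are taken at an interpolated point y, while a running average x is the iterate you actually evaluate (which is what `schedule_free_eval_params` exposes in optax). The toy loss and names here are illustrative:

```python
def grad_q(w):
    """Gradient of the toy loss f(w) = (w - 3)**2."""
    return 2.0 * (w - 3.0)

def schedule_free_sgd_run(steps=500, lr=0.1, beta=0.9):
    z = x = 0.0                          # base iterate z, averaged iterate x
    for t in range(1, steps + 1):
        y = (1 - beta) * z + beta * x    # gradients are evaluated at y
        z = z - lr * grad_q(y)           # SGD step on the base iterate
        c = 1.0 / t                      # uniform averaging weight
        x = (1 - c) * x + c * z          # x is what you evaluate/deploy
    return x

x_eval = schedule_free_sgd_run()         # close to the minimizer w = 3
```

Because the averaging replaces the role of a decaying learning-rate schedule, no schedule needs to be tuned; only the evaluation iterate x differs from the training iterate z.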

### Momentum-Based Methods

#### Muon Optimizer

```python { .api }
def muon(learning_rate, momentum=0.95, nesterov=False):
    """Muon optimizer with orthogonalized momentum updates for matrix-shaped parameters.

    Args:
        learning_rate: Learning rate
        momentum: Momentum coefficient (default: 0.95)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation: Muon optimizer
    """
```

#### MoMo (Momentum Models)

```python { .api }
def momo(learning_rate, momentum=0.9):
    """MoMo optimizer using momentum models for adaptive learning rates.

    Args:
        learning_rate: Learning rate
        momentum: Base momentum coefficient (default: 0.9)

    Returns:
        GradientTransformation: MoMo optimizer
    """

def momo_adam(learning_rate, beta1=0.9, beta2=0.999, eps=1e-8):
    """MoMo-Adam combining momentum models with Adam.

    Args:
        learning_rate: Learning rate
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)

    Returns:
        GradientTransformation: MoMo-Adam optimizer
    """
```

### Specialized Methods

#### DoG (Distance over Gradients) and DoWG

```python { .api }
def dog(learning_rate, rho=0.05, eps=1e-8):
    """DoG (Distance over Gradients) optimizer.

    Args:
        learning_rate: Learning rate
        rho: Difference parameter (default: 0.05)
        eps: Numerical stability parameter (default: 1e-8)

    Returns:
        GradientTransformation: DoG optimizer
    """

def dowg(learning_rate, rho=0.05, eps=1e-8, weight_decay=0.0):
    """DoWG (Distance over Weighted Gradients) optimizer.

    Args:
        learning_rate: Learning rate
        rho: Difference parameter (default: 0.05)
        eps: Numerical stability parameter (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: DoWG optimizer
    """
```
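The DoG step-size rule is simple enough to sketch in a few lines: the step size is the maximum distance travelled from the initial point divided by the root of the accumulated squared gradient norms. A scalar sketch on f(w) = (w - 1)², with illustrative names (`r_eps` is the small initial distance estimate the method needs; this is not the optax implementation):

```python
import math

def grad_d(w):
    """Gradient of the toy loss f(w) = (w - 1)**2."""
    return 2.0 * (w - 1.0)

def dog_run(w0=0.0, steps=300, r_eps=1e-4):
    w = w0
    max_dist = r_eps        # running max distance from the initial point
    grad_sq_sum = 0.0       # running sum of squared gradient norms
    for _ in range(steps):
        g = grad_d(w)
        grad_sq_sum += g * g
        eta = max_dist / math.sqrt(grad_sq_sum)   # the DoG step size
        w = w - eta * g
        max_dist = max(max_dist, abs(w - w0))
    return w

w_dog = dog_run()           # approaches the minimizer w = 1
```

The rule is parameter-free in the sense that no learning rate is tuned: the distance numerator grows on its own until the iterates stop moving.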

#### ADOPT

```python { .api }
def adopt(learning_rate, eps=1e-8, beta1=0.9, beta2=0.9999, weight_decay=0.0):
    """ADOPT optimizer, an Adam variant designed to converge for any choice of beta2.

    Args:
        learning_rate: Learning rate
        eps: Numerical stability parameter (default: 1e-8)
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.9999)
        weight_decay: Weight decay coefficient (default: 0.0)

    Returns:
        GradientTransformation: ADOPT optimizer
    """
```

### Privacy-Preserving Methods

#### Differential Privacy

```python { .api }
def differentially_private_aggregate(
    inner_agg_factory,
    l2_norm_bound,
    noise_multiplier,
    seed=None
):
    """Differentially private gradient aggregation.

    Args:
        inner_agg_factory: Base aggregation function
        l2_norm_bound: L2 norm bound for gradient clipping
        noise_multiplier: Noise multiplier for privacy
        seed: Random seed (default: None)

    Returns:
        GradientTransformation: DP aggregation function
    """
```
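The clip-and-noise pattern behind DP aggregation can be sketched in plain Python: clip each per-example gradient to the L2 bound, sum, and add Gaussian noise scaled by noise_multiplier × bound. This is a hypothetical sketch with made-up names, not the optax API:

```python
import math
import random

def dp_aggregate(per_example_grads, l2_norm_bound, noise_multiplier, seed=0):
    """Clip each per-example gradient to the L2 bound, sum, add Gaussian noise."""
    rng = random.Random(seed)
    total = [0.0] * len(per_example_grads[0])
    for g in per_example_grads:
        norm = math.sqrt(sum(x * x for x in g))
        scale = min(1.0, l2_norm_bound / max(norm, 1e-12))  # clip factor
        for i, x in enumerate(g):
            total[i] += x * scale
    sigma = noise_multiplier * l2_norm_bound  # noise scaled to the bound
    return [x + rng.gauss(0.0, sigma) for x in total]

grads_demo = [[3.0, 4.0], [0.3, 0.4]]   # L2 norms 5.0 and 0.5
agg = dp_aggregate(grads_demo, l2_norm_bound=1.0, noise_multiplier=0.0)
# With zero noise, the first gradient is scaled down to norm 1: agg ≈ [0.9, 1.2]
```

Clipping bounds each example's influence on the sum, which is what lets the added Gaussian noise yield a differential-privacy guarantee.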

### Experimental Adaptive Methods

#### AdEMAMix

```python { .api }
def ademamix(learning_rate, beta1=0.9, beta2=0.999, eps=1e-8, alpha=5.0):
    """AdEMAMix optimizer with exponential moving average mixing.

    Args:
        learning_rate: Learning rate
        beta1: First moment decay rate (default: 0.9)
        beta2: Second moment decay rate (default: 0.999)
        eps: Numerical stability parameter (default: 1e-8)
        alpha: Mixing parameter (default: 5.0)

    Returns:
        GradientTransformation: AdEMAMix optimizer
    """
```
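A scalar sketch of the assumed AdEMAMix form: an Adam-style step whose momentum mixes a fast EMA with a slow, long-memory EMA weighted by `alpha`. Here `beta3` is an assumed extra parameter for the slow EMA (not shown in the signature above), and the real implementation also schedules `alpha` and the slow decay over training; this is illustrative only:

```python
def ademamix_steps(w0, steps=300, lr=0.02, beta1=0.9, beta2=0.999,
                   beta3=0.9999, alpha=5.0, eps=1e-8):
    w, m_fast, m_slow, nu = w0, 0.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = 2.0 * w                               # gradient of f(w) = w**2
        m_fast = beta1 * m_fast + (1 - beta1) * g     # fast EMA (Adam-like)
        m_slow = beta3 * m_slow + (1 - beta3) * g     # slow, long-memory EMA
        nu = beta2 * nu + (1 - beta2) * g * g         # second moment
        m_hat = m_fast / (1 - beta1 ** t)             # bias correction
        nu_hat = nu / (1 - beta2 ** t)
        # The update mixes the two EMAs, weighting the slow one by alpha.
        w = w - lr * (m_hat + alpha * m_slow) / (nu_hat ** 0.5 + eps)
    return w

w_mix = ademamix_steps(1.0)
```

The slow EMA lets very old gradients keep contributing, which is the paper's "Better, Faster, Older" idea.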

#### COCOB

```python { .api }
def cocob():
    """COCOB (Continuous Coin Betting) optimizer.

    Returns:
        GradientTransformation: COCOB optimizer (parameter-free)
    """
```

## Usage Examples

```python
import optax

# Using SAM for better generalization
base_optimizer = optax.sgd(0.1)
sam_optimizer = optax.contrib.sam(base_optimizer, rho=0.05)

# Using schedule-free optimizers
sf_adamw = optax.contrib.schedule_free_adamw(learning_rate=0.001)

# Using experimental adaptive methods
prodigy_opt = optax.contrib.prodigy(learning_rate=1.0)
sophia_opt = optax.contrib.sophia(learning_rate=0.001)

# Training loop with schedule-free optimizer
opt_state = sf_adamw.init(params)
for step in range(num_steps):
    grads = compute_gradients(params, data)
    updates, opt_state = sf_adamw.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Extract evaluation parameters (for schedule-free methods)
    if step % eval_interval == 0:
        eval_params = optax.contrib.schedule_free_eval_params(opt_state, params)
        eval_loss = evaluate(eval_params, eval_data)
```

## Import

```python
import optax.contrib
# or
from optax.contrib import sam, prodigy, schedule_free_adamw
```

## Research Papers

Many contrib optimizers are based on recent research:

- **SAM**: "Sharpness-Aware Minimization for Efficiently Improving Generalization"
- **Prodigy**: "Prodigy: An Expeditiously Adaptive Parameter-Free Learner"
- **Sophia**: "Sophia: A Scalable Stochastic Second-order Optimizer"
- **Schedule-Free**: "The Road Less Scheduled"
- **AdEMAMix**: "The AdEMAMix Optimizer: Better, Faster, Older"

Refer to the respective papers for detailed algorithmic descriptions and theoretical analysis.