# Optax

A gradient processing and optimization library in JAX. Optax provides modular building blocks that can be easily recombined to create custom optimizers and gradient processing components. The library offers implementations of many popular optimizers, loss functions, and gradient transformations with a focus on composability and research productivity.

## Package Information

- **Package Name**: optax
- **Language**: Python
- **Installation**: `pip install optax`
- **Documentation**: https://optax.readthedocs.io/

## Core Imports

```python
import optax
```

Common usage patterns:

```python
# Import specific optimizers
from optax import adam, sgd, adamw

# Import transformations and utilities
from optax import apply_updates, chain

# Import loss functions
from optax import l2_loss, softmax_cross_entropy

# Import schedules
from optax import linear_schedule, cosine_decay_schedule
```

## Basic Usage

```python
import jax
import jax.numpy as jnp
import optax

# Initialize model parameters
params = {'w': jnp.ones((10,)), 'b': jnp.zeros((1,))}

# Create an optimizer
optimizer = optax.adam(learning_rate=0.001)

# Initialize optimizer state
opt_state = optimizer.init(params)

# Define a simple loss function
def loss_fn(params, x, y):
    pred = params['w'].dot(x) + params['b']
    # l2_loss is element-wise; reduce to a scalar so jax.grad can differentiate it
    return jnp.mean(optax.l2_loss(pred, y))

# Training step
def train_step(params, opt_state, x, y):
    # Compute gradients
    grads = jax.grad(loss_fn)(params, x, y)

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    return params, opt_state

# Example training data
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y = jnp.array([2.0])

# Perform training step
params, opt_state = train_step(params, opt_state, x, y)
```

## Architecture

Optax is built around three key concepts:

- **GradientTransformation**: Core abstraction with `init` and `update` functions that process gradients
- **Composability**: Transformations can be chained together using `optax.chain()` to create custom optimizers
- **Modularity**: Small building blocks that can be recombined in custom ways for research flexibility

The library provides implementations at multiple levels of abstraction:

- High-level optimizers (adam, sgd, etc.) that are ready to use
- Mid-level gradient transformations that can be combined
- Low-level utilities for building custom components

## Capabilities

### Core Optimizers

Popular optimization algorithms including Adam, SGD, RMSprop, Adagrad, and many others. These are complete optimizers ready for immediate use in training loops.

```python { .api }
def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def sgd(learning_rate, momentum=None, nesterov=False): ...
def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4, *, nesterov=False): ...
def rmsprop(learning_rate, decay=0.9, eps=1e-8): ...
def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7): ...
```

[Core Optimizers](./optimizers.md)

### Advanced Optimizers

Specialized and experimental optimization algorithms including second-order methods, adaptive variants, and research optimizers.

```python { .api }
def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=0.0): ...
def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.): ...
def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None): ...
def lbfgs(learning_rate, ...): ...
def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3, initial_accumulator=1e-6): ...
```

[Advanced Optimizers](./advanced-optimizers.md)

### Gradient Transformations

Building blocks for creating custom optimizers including scaling, clipping, noise addition, and momentum accumulation. These can be combined using `chain()` to build custom optimization strategies.

```python { .api }
def scale(step_size): ...
def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def clip_by_global_norm(max_norm): ...
def add_decayed_weights(weight_decay, mask=None): ...
def trace(decay, nesterov=False, accumulator_dtype=None): ...
def chain(*transformations): ...
```

[Gradient Transformations](./transformations.md)

### Loss Functions

Comprehensive collection of loss functions for classification, regression, and structured prediction tasks.

```python { .api }
def l2_loss(predictions, targets): ...
def softmax_cross_entropy(logits, labels, axis=-1): ...
def sigmoid_binary_cross_entropy(logits, labels): ...
def huber_loss(predictions, targets, delta=1.0): ...
def hinge_loss(scores, labels): ...
```

[Loss Functions](./losses.md)

### Learning Rate Schedules

Flexible scheduling functions for learning rates and other hyperparameters including warmup, decay, and cyclic schedules.

```python { .api }
def constant_schedule(value): ...
def linear_schedule(init_value, end_value, transition_steps): ...
def cosine_decay_schedule(init_value, decay_steps, alpha=0.0): ...
def exponential_decay(init_value, decay_rate, transition_steps, ...): ...
def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value): ...
```

[Schedules](./schedules.md)

### Utilities and Tree Operations

Utility functions for parameter updates, tree operations, numerical stability, and working with JAX pytrees.

```python { .api }
def apply_updates(params, updates): ...
def global_norm(updates): ...
def safe_norm(x, min_norm=0.0, ord=None): ...
class GradientTransformation: ...
class OptState: ...
class Params: ...
```

[Utilities](./utilities.md)

### Assignment Operations

Linear assignment algorithms including the Hungarian algorithm for solving optimal assignment problems.

```python { .api }
def hungarian_algorithm(cost_matrix): ...
def base_hungarian_algorithm(cost_matrix): ...
```

[Assignment Operations](./assignment.md)

### Monte Carlo Gradient Estimation

Utilities for Monte Carlo gradient estimation methods including score function, pathwise, and measure-valued estimators. **Note**: These functions are deprecated and will be removed in version 0.3.0.

```python { .api }
def score_function_jacobians(function, params, dist_builder, rng, num_samples): ...
def pathwise_jacobians(function, params, dist_builder, rng, num_samples): ...
def measure_valued_jacobians(function, params, dist_builder, rng, num_samples, coupling=True): ...
```

[Monte Carlo Methods](./monte-carlo.md)

### Perturbation-Based Optimization

Utilities for making non-differentiable functions differentiable through stochastic perturbations.

```python { .api }
def make_perturbed_fun(fun, num_samples=1000, sigma=0.1, noise=Gumbel(), use_baseline=True): ...
class Gumbel: ...
class Normal: ...
```

[Perturbations](./perturbations.md)

### Constraint Projections

Projection functions for enforcing constraints in optimization by projecting parameters onto feasible sets.

```python { .api }
def projection_l2_ball(params, radius=1.0): ...
def projection_simplex(params): ...
def projection_box(params, lower=None, upper=None): ...
```

[Projections](./projections.md)

### Second-Order Methods

Utilities for second-order optimization including Hessian computations and Fisher information.

```python { .api }
def hessian_diag(fun): ...
def fisher_diag(log_likelihood): ...
def hvp(fun, primals, tangents): ...
```

[Second-Order Methods](./second-order.md)

### Tree Utilities

JAX PyTree manipulation utilities for working with nested parameter structures.

```python { .api }
def tree_add(tree_a, tree_b): ...
def tree_scale(tree, scalar): ...
def tree_zeros_like(tree): ...
```

[Tree Utilities](./tree-utilities.md)

### Experimental Features

The `optax.contrib` module contains experimental optimizers and techniques under active development, including SAM, Prodigy, Sophia, and schedule-free optimizers.

```python { .api }
# Sharpness-Aware Minimization
def sam(base_optimizer, rho=0.05, normalize=True): ...

# Advanced adaptive optimizers
def prodigy(learning_rate=1.0, eps=1e-8, beta1=0.9, beta2=0.999, weight_decay=0.0): ...
def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4): ...

# Schedule-free optimizers
def schedule_free_adamw(learning_rate=0.0025, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0): ...
```

[Experimental Optimizers](./contrib.md)


## Types

```python { .api }
# Core type aliases
OptState = chex.ArrayTree  # Optimizer state
Params = chex.ArrayTree  # Model parameters
Updates = Params  # Gradient updates
Schedule = Callable[[chex.Numeric], chex.Numeric]  # Schedule function
ScalarOrSchedule = Union[float, jax.Array, Schedule]

# Core classes
class GradientTransformation(NamedTuple):
    init: Callable[[Params], OptState]
    update: Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]

class GradientTransformationExtraArgs(NamedTuple):
    init: Callable[[Params], OptState]
    update: Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]

class EmptyState(NamedTuple):
    """Empty state for stateless transformations"""
    pass

# Transformation function types
TransformInitFn = Callable[[Params], OptState]
TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]
TransformUpdateExtraArgsFn = Callable[..., Tuple[Updates, OptState]]

# Optimizer state classes
class ScaleByAdamState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates

class ScaleByRmsState(NamedTuple):
    count: chex.Array
    nu: Updates

class ScaleByScheduleState(NamedTuple):
    count: chex.Array

class FactoredState(NamedTuple):
    v_row: chex.Array
    v_col: chex.Array
    v: chex.Array

class LookaheadParams(NamedTuple):
    slow: Params

LookaheadState = LookaheadParams

class ApplyEvery(NamedTuple):
    count: chex.Array
    grad_acc: Updates

# Tree and projection types
MaskOrFn = Union[chex.Array, Callable[[Params], chex.Array]]
MaskedNode = Any

# Schedule types
WrappedSchedule = Callable[[chex.Numeric], chex.Numeric]

# Assignment types (from optax.assignment)
CostMatrix = chex.Array
Assignment = Tuple[chex.Array, chex.Array]  # (row_indices, col_indices)

# Monte Carlo types (from optax.monte_carlo) - deprecated
ControlVariate = Tuple[Callable, Callable, Callable]
CvState = Any

# Perturbation types (from optax.perturbations)
NoiseDistribution = Any  # Objects with sample() and log_prob() methods

# Contrib optimizer state classes (experimental)
class ScaleByAdemamixState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates

class MuonState(NamedTuple):
    momentum: Updates

class COCOBState(NamedTuple):
    sum_grad_squared: Updates
    sum_grad: Updates

class DoGState(NamedTuple):
    momentum: Updates

# Additional state classes for contrib optimizers
DAdaptAdamWState = Any
MechanicState = Any
MomoState = Any
MomoAdamState = Any
DoWGState = Any
ScaleBySimplifiedAdEMAMixState = Any
DifferentiallyPrivateAggregateState = Any

# Linesearch types
class ScaleByBacktrackingLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array

class ScaleByZoomLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array

class ZoomLinesearchInfo(NamedTuple):
    failed: bool
    nfev: int
    ngev: int
    k: int
```