npx @tessl/cli install tessl/pypi-optax@0.2.00
# Optax

A gradient processing and optimization library in JAX. Optax provides modular building blocks that can be easily recombined to create custom optimizers and gradient processing components. The library offers implementations of many popular optimizers, loss functions, and gradient transformations with a focus on composability and research productivity.

## Package Information

- **Package Name**: optax
- **Language**: Python
- **Installation**: `pip install optax`
- **Documentation**: https://optax.readthedocs.io/

## Core Imports

```python
import optax
```

Common usage patterns:

```python
# Import specific optimizers
from optax import adam, sgd, adamw

# Import transformations and utilities
from optax import apply_updates, chain

# Import loss functions
from optax import l2_loss, softmax_cross_entropy

# Import schedules
from optax import linear_schedule, cosine_decay_schedule
```

## Basic Usage

```python
import jax
import jax.numpy as jnp
import optax

# Initialize model parameters
params = {'w': jnp.ones((10,)), 'b': jnp.zeros((1,))}

# Create an optimizer
optimizer = optax.adam(learning_rate=0.001)

# Initialize optimizer state
opt_state = optimizer.init(params)

# Define a simple loss function
def loss_fn(params, x, y):
    pred = params['w'].dot(x) + params['b']
    # l2_loss is elementwise; reduce to a scalar so jax.grad can differentiate it
    return jnp.mean(optax.l2_loss(pred, y))

# Training step
def train_step(params, opt_state, x, y):
    # Compute gradients
    grads = jax.grad(loss_fn)(params, x, y)

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, opt_state

# Example training data
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y = jnp.array([2.0])

# Perform training step
params, opt_state = train_step(params, opt_state, x, y)
```

## Architecture

Optax is built around three key concepts:

- **GradientTransformation**: Core abstraction with `init` and `update` functions that process gradients
- **Composability**: Transformations can be chained together using `optax.chain()` to create custom optimizers
- **Modularity**: Small building blocks that can be recombined in custom ways for research flexibility

The library provides implementations at multiple levels of abstraction:

- High-level optimizers (adam, sgd, etc.) that are ready to use
- Mid-level gradient transformations that can be combined
- Low-level utilities for building custom components

## Capabilities

### Core Optimizers

Popular optimization algorithms including Adam, SGD, RMSprop, Adagrad, and many others. These are complete optimizers ready for immediate use in training loops.

```python { .api }
def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def sgd(learning_rate, momentum=None, nesterov=False): ...
def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4, *, nesterov=False): ...
def rmsprop(learning_rate, decay=0.9, eps=1e-8): ...
def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7): ...
```

[Core Optimizers](./optimizers.md)

### Advanced Optimizers

Specialized and experimental optimization algorithms including second-order methods, adaptive variants, and research optimizers.

```python { .api }
def lion(learning_rate, b1=0.9, b2=0.99, weight_decay=1e-3): ...
def lars(learning_rate, weight_decay=0., trust_coefficient=0.001, eps=0.): ...
def lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0., mask=None): ...
def lbfgs(learning_rate, ...): ...
def yogi(learning_rate, b1=0.9, b2=0.999, eps=1e-3, initial_accumulator=1e-6): ...
```

[Advanced Optimizers](./advanced-optimizers.md)

### Gradient Transformations

Building blocks for creating custom optimizers including scaling, clipping, noise addition, and momentum accumulation. These can be combined using `chain()` to build custom optimization strategies.

```python { .api }
def scale(step_size): ...
def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False): ...
def clip_by_global_norm(max_norm): ...
def add_decayed_weights(weight_decay, mask=None): ...
def trace(decay, nesterov=False, accumulator_dtype=None): ...
def chain(*transformations): ...
```

[Gradient Transformations](./transformations.md)

### Loss Functions

Comprehensive collection of loss functions for classification, regression, and structured prediction tasks.

```python { .api }
def l2_loss(predictions, targets): ...
def softmax_cross_entropy(logits, labels, axis=-1): ...
def sigmoid_binary_cross_entropy(logits, labels): ...
def huber_loss(predictions, targets, delta=1.0): ...
def hinge_loss(scores, labels): ...
```

[Loss Functions](./losses.md)

### Learning Rate Schedules

Flexible scheduling functions for learning rates and other hyperparameters including warmup, decay, and cyclic schedules.

```python { .api }
def constant_schedule(value): ...
def linear_schedule(init_value, end_value, transition_steps): ...
def cosine_decay_schedule(init_value, decay_steps, alpha=0.0): ...
def exponential_decay(init_value, transition_steps, decay_rate, ...): ...
def warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0): ...
```

[Schedules](./schedules.md)

### Utilities and Tree Operations

Utility functions for parameter updates, tree operations, numerical stability, and working with JAX pytrees.

```python { .api }
def apply_updates(params, updates): ...
def global_norm(updates): ...
def safe_norm(x, min_norm=0.0, ord=None): ...
class GradientTransformation: ...
class OptState: ...
class Params: ...
```

[Utilities](./utilities.md)

### Assignment Operations

Linear assignment algorithms including the Hungarian algorithm for solving optimal assignment problems.

```python { .api }
def hungarian_algorithm(cost_matrix): ...
def base_hungarian_algorithm(cost_matrix): ...
```

[Assignment Operations](./assignment.md)

### Monte Carlo Gradient Estimation

Utilities for Monte Carlo gradient estimation methods including score function, pathwise, and measure-valued estimators. **Note**: These functions are deprecated and will be removed in version 0.3.0.

```python { .api }
def score_function_jacobians(function, params, dist_builder, rng, num_samples): ...
def pathwise_jacobians(function, params, dist_builder, rng, num_samples): ...
def measure_valued_jacobians(function, params, dist_builder, rng, num_samples, coupling=True): ...
```

[Monte Carlo Methods](./monte-carlo.md)

### Perturbation-Based Optimization

Utilities for making non-differentiable functions differentiable through stochastic perturbations.

```python { .api }
def make_perturbed_fun(fun, num_samples=1000, sigma=0.1, noise=Gumbel(), use_baseline=True): ...
class Gumbel: ...
class Normal: ...
```

[Perturbations](./perturbations.md)

### Constraint Projections

Projection functions for enforcing constraints in optimization by projecting parameters onto feasible sets.

```python { .api }
def projection_l2_ball(params, radius=1.0): ...
def projection_simplex(params): ...
def projection_box(params, lower=None, upper=None): ...
```

[Projections](./projections.md)

### Second-Order Methods

Utilities for second-order optimization including Hessian computations and Fisher information.

```python { .api }
def hessian_diag(fun): ...
def fisher_diag(log_likelihood): ...
def hvp(fun, primals, tangents): ...
```

[Second-Order Methods](./second-order.md)

### Tree Utilities

JAX PyTree manipulation utilities for working with nested parameter structures.

```python { .api }
def tree_add(tree_a, tree_b): ...
def tree_scale(tree, scalar): ...
def tree_zeros_like(tree): ...
```

[Tree Utilities](./tree-utilities.md)

### Experimental Features

The `optax.contrib` module contains experimental optimizers and techniques under active development, including SAM, Prodigy, Sophia, and schedule-free optimizers.

```python { .api }
# Sharpness-Aware Minimization
def sam(base_optimizer, rho=0.05, normalize=True): ...

# Advanced adaptive optimizers
def prodigy(learning_rate=1.0, eps=1e-8, beta1=0.9, beta2=0.999, weight_decay=0.0): ...
def sophia(learning_rate, beta1=0.965, beta2=0.99, eps=1e-8, weight_decay=1e-4): ...

# Schedule-free optimizers
def schedule_free_adamw(learning_rate=0.0025, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0): ...
```

[Experimental Optimizers](./contrib.md)

## Types

```python { .api }
# Core type aliases
OptState = chex.ArrayTree  # Optimizer state
Params = chex.ArrayTree  # Model parameters
Updates = Params  # Gradient updates
Schedule = Callable[[chex.Numeric], chex.Numeric]  # Schedule function
ScalarOrSchedule = Union[float, jax.Array, Schedule]

# Core classes
class GradientTransformation(NamedTuple):
    init: Callable[[Params], OptState]
    update: Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]

class GradientTransformationExtraArgs(NamedTuple):
    init: Callable[[Params], OptState]
    # Same as above, but update may also accept extra keyword arguments
    update: Callable[..., Tuple[Updates, OptState]]

class EmptyState(NamedTuple):
    """Empty state for stateless transformations"""
    pass

# Transformation function types
TransformInitFn = Callable[[Params], OptState]
TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]
TransformUpdateExtraArgsFn = Callable[..., Tuple[Updates, OptState]]

# Optimizer state classes
class ScaleByAdamState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates

class ScaleByRmsState(NamedTuple):
    count: chex.Array
    nu: Updates

class ScaleByScheduleState(NamedTuple):
    count: chex.Array

class FactoredState(NamedTuple):
    v_row: chex.Array
    v_col: chex.Array
    v: chex.Array

class LookaheadParams(NamedTuple):
    slow: Params
LookaheadState = LookaheadParams

class ApplyEvery(NamedTuple):
    count: chex.Array
    grad_acc: Updates

# Tree and projection types
MaskOrFn = Union[chex.Array, Callable[[Params], chex.Array]]
MaskedNode = Any

# Schedule types
WrappedSchedule = Callable[[chex.Numeric], chex.Numeric]

# Assignment types (from optax.assignment)
CostMatrix = chex.Array
Assignment = Tuple[chex.Array, chex.Array]  # (row_indices, col_indices)

# Monte Carlo types (from optax.monte_carlo) - deprecated
ControlVariate = Tuple[Callable, Callable, Callable]
CvState = Any

# Perturbation types (from optax.perturbations)
NoiseDistribution = Any  # Objects with sample() and log_prob() methods

# Contrib optimizer state classes (experimental)
class ScaleByAdemamixState(NamedTuple):
    count: chex.Array
    mu: Updates
    nu: Updates

class MuonState(NamedTuple):
    momentum: Updates

class COCOBState(NamedTuple):
    sum_grad_squared: Updates
    sum_grad: Updates

class DoGState(NamedTuple):
    momentum: Updates

# Additional state classes for contrib optimizers
DAdaptAdamWState = Any
MechanicState = Any
MomoState = Any
MomoAdamState = Any
DoWGState = Any
ScaleBySimplifiedAdEMAMixState = Any
DifferentiallyPrivateAggregateState = Any

# Linesearch types
class ScaleByBacktrackingLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array

class ScaleByZoomLinesearchState(NamedTuple):
    count: chex.Array
    f_eval: chex.Array

class ZoomLinesearchInfo(NamedTuple):
    failed: bool
    nfev: int
    ngev: int
    k: int
```