# Core Optimizers

Popular optimization algorithms that are ready for immediate use in training loops. These optimizers combine multiple gradient transformations into complete optimization strategies with sensible defaults.

## Capabilities

### Adam Optimizer

The Adam optimizer with optional Nesterov momentum. Combines adaptive learning rates with momentum for efficient optimization across a wide range of problems.

```python { .api }
def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0, mu_dtype=None, *, nesterov=False):
    """
    Adam optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant added to the denominator, outside the square root (default: 1e-8)
        eps_root: Small constant added to the denominator, inside the square root (default: 0.0)
        mu_dtype: Optional dtype for the first-moment (momentum) accumulator; if None, inferred from the parameters (default: None)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformationExtraArgs
    """
```
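The `learning_rate` argument accepts either a fixed value or an Optax schedule. Below is a minimal usage sketch; the cosine schedule and its parameters are illustrative choices, not part of the API above:

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros((3, 3))}

# A schedule can be passed anywhere a learning rate is expected.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
opt = optax.adam(learning_rate=schedule, nesterov=True)  # Nesterov-momentum variant

state = opt.init(params)
grads = {'w': jnp.ones((3, 3))}  # stand-in gradients
updates, state = opt.update(grads, state)
params = optax.apply_updates(params, updates)
```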
### AdamW Optimizer

Adam optimizer with decoupled weight decay. Separates weight decay from gradient-based updates for better generalization.

```python { .api }
def adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4, *, nesterov=False):
    """
    AdamW optimizer with decoupled weight decay.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        weight_decay: Weight decay coefficient (default: 1e-4)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation
    """
```
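Because the decay term is computed from the parameter values themselves, `adamw` needs the current parameters at update time. A minimal sketch (shapes are illustrative):

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
grads = {'w': jnp.ones((4, 4)), 'b': jnp.ones((4,))}  # stand-in gradients

opt = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
state = opt.init(params)

# Pass params so the decoupled weight-decay term can be applied.
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)
```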
### Stochastic Gradient Descent

Classic SGD optimizer with optional momentum and Nesterov acceleration.

```python { .api }
def sgd(learning_rate, momentum=None, nesterov=False):
    """
    Stochastic gradient descent optimizer.

    Args:
        learning_rate: Learning rate or schedule
        momentum: Momentum coefficient (default: None for no momentum)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation
    """
```
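Without momentum, `sgd` simply scales each gradient by the negative learning rate; with `momentum` (and optionally `nesterov=True`) it also keeps a velocity accumulator. A minimal sketch with illustrative values:

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.array([1.0, 2.0])}
grads = {'w': jnp.array([0.5, -0.5])}

opt = optax.sgd(learning_rate=0.1)  # no momentum
state = opt.init(params)
updates, state = opt.update(grads, state)
# Plain SGD: each update is -learning_rate * gradient.
print(updates['w'])  # [-0.05  0.05]

# With momentum and Nesterov acceleration:
opt_mom = optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True)
state_mom = opt_mom.init(params)
```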
### RMSprop Optimizer

RMSprop optimizer with adaptive learning rates based on recent gradient magnitudes.

```python { .api }
def rmsprop(learning_rate, decay=0.9, eps=1e-8):
    """
    RMSprop optimizer.

    Args:
        learning_rate: Learning rate or schedule
        decay: Decay rate for moving average of squared gradients (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```
### Adagrad Optimizer

Adagrad optimizer with adaptive learning rates that decrease over time.

```python { .api }
def adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-7):
    """
    Adagrad optimizer.

    Args:
        learning_rate: Learning rate or schedule
        initial_accumulator_value: Initial value for accumulator (default: 0.1)
        eps: Small constant for numerical stability (default: 1e-7)

    Returns:
        GradientTransformation
    """
```
### Adadelta Optimizer

Adadelta optimizer that adapts learning rates based on a moving window of gradient updates.

```python { .api }
def adadelta(learning_rate=1.0, rho=0.9, eps=1e-6):
    """
    Adadelta optimizer.

    Args:
        learning_rate: Learning rate (default: 1.0)
        rho: Decay rate for moving averages (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-6)

    Returns:
        GradientTransformation
    """
```
### Adamax Optimizer

Adamax optimizer, a variant of Adam based on the infinity norm.

```python { .api }
def adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Adamax optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for exponentially weighted infinity norm (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```
### Nadam Optimizer

Nesterov-accelerated Adam optimizer combining Adam with Nesterov momentum.

```python { .api }
def nadam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """
    Nadam optimizer (Nesterov-accelerated Adam).

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```
### AdaBelief Optimizer

AdaBelief optimizer that adapts the step size according to the "belief" in the observed gradients.

```python { .api }
def adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16, *, nesterov=False):
    """
    AdaBelief optimizer.

    Args:
        learning_rate: Learning rate or schedule
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant added to the denominator, outside the square root (default: 1e-16)
        eps_root: Small constant added to the denominator, inside the square root (default: 1e-16)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation
    """
```
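All of the optimizers in this section expose the same `init`/`update` interface, so they can be swapped into a training step without code changes. A sketch using a generic helper (`make_step` is an illustrative name, not part of the API):

```python
import jax.numpy as jnp
import optax

def make_step(opt):
    """Build a training step for any optimizer with the init/update interface."""
    def step(params, opt_state, grads):
        updates, opt_state = opt.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state
    return step

params = {'w': jnp.ones((2, 2))}
grads = {'w': jnp.ones((2, 2))}  # stand-in gradients

# Any of the optimizers in this section can be dropped in here.
for opt in (optax.rmsprop(1e-3), optax.adagrad(1e-2), optax.adabelief(1e-3)):
    state = opt.init(params)
    new_params, state = make_step(opt)(params, state, grads)
```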
188
189
## Usage Example
190
191
```python
192
import optax
193
import jax.numpy as jnp
194
195
# Initialize parameters
196
params = {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros((5,))}
197
198
# Create different optimizers
199
adam_opt = optax.adam(learning_rate=0.001)
200
sgd_opt = optax.sgd(learning_rate=0.01, momentum=0.9)
201
adamw_opt = optax.adamw(learning_rate=0.001, weight_decay=1e-4)
202
203
# Initialize optimizer state
204
adam_state = adam_opt.init(params)
205
sgd_state = sgd_opt.init(params)
206
adamw_state = adamw_opt.init(params)
207
208
# In training loop (example with Adam)
209
def training_step(params, opt_state, gradients):
210
updates, new_opt_state = adam_opt.update(gradients, opt_state)
211
new_params = optax.apply_updates(params, updates)
212
return new_params, new_opt_state
213
```
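To turn the snippet above into an end-to-end loop, the sketch below adds an illustrative linear model with a mean-squared-error loss and uses `jax.grad` to produce the gradients; the model, loss, and data shapes are assumptions for demonstration only:

```python
import jax
import jax.numpy as jnp
import optax

params = {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros((5,))}
adam_opt = optax.adam(learning_rate=0.001)
opt_state = adam_opt.init(params)

def loss_fn(params, x, y):
    # Simple linear model with a mean-squared-error loss (illustrative only).
    pred = x @ params['weights'] + params['bias']
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = adam_opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

x = jnp.ones((32, 10))  # dummy batch of inputs
y = jnp.zeros((32, 5))  # dummy targets
for _ in range(5):
    params, opt_state = train_step(params, opt_state, x, y)
```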