# Optimization

NumPyro provides a collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation. All optimizers are built on JAX for efficient automatic differentiation and support JIT compilation for high-performance optimization.

## Capabilities

### Core Optimizer Infrastructure

Base classes and utilities for the optimization system.
```python { .api }
class Optimizer:
    """
    Base class for optimizers in NumPyro.

    All optimizers follow the same interface pattern for consistency
    with JAX optimization libraries like optax.
    """

    def init(self, params: dict) -> Any:
        """
        Initialize optimizer state.

        Args:
            params: Initial parameter values

        Returns:
            Initial optimizer state
        """

    def update(self, grads: dict, state: Any, params: dict) -> tuple:
        """
        Update parameters based on gradients.

        Args:
            grads: Parameter gradients
            state: Current optimizer state
            params: Current parameter values

        Returns:
            Tuple of (updates, new_state)
        """

    def get_params(self, state: Any) -> dict:
        """Get current parameter values from optimizer state."""
```
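To make the interface concrete, here is a minimal pure-Python sketch of an optimizer implementing the `init`/`update`/`get_params` pattern above, using plain SGD on a dict of scalar parameters. This is illustrative only: the class name `PlainSGD` is hypothetical, and real NumPyro optimizers operate on JAX pytrees rather than plain dicts.

```python
class PlainSGD:
    """Illustrative stand-in for the Optimizer interface; not NumPyro API."""

    def __init__(self, step_size=0.1):
        self.step_size = step_size

    def init(self, params):
        # SGD is stateless apart from the parameters themselves.
        return {"params": dict(params)}

    def update(self, grads, state, params):
        # Return (updates, new_state); updates are the deltas to apply.
        updates = {k: -self.step_size * g for k, g in grads.items()}
        new_params = {k: params[k] + updates[k] for k in params}
        return updates, {"params": new_params}

    def get_params(self, state):
        return state["params"]
```

With this shape, a training loop computes gradients, calls `update`, and reads the new parameters back with `get_params`, exactly as the docstrings above describe.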
### Adaptive Learning Rate Optimizers

Optimizers that adapt learning rates based on gradient history.
```python { .api }
class Adam:
    """
    Adaptive Moment Estimation (Adam) optimizer.

    Computes individual adaptive learning rates for different parameters from
    estimates of first and second moments of the gradients.

    Args:
        step_size: Learning rate (default: 0.001)
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)

    Usage:
        optimizer = Adam(step_size=0.01)
        opt_state = optimizer.init(params)

        for step in range(num_steps):
            grads = compute_gradients(params)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            params = apply_updates(params, updates)
    """
    def __init__(self, step_size: float = 0.001, b1: float = 0.9,
                 b2: float = 0.999, eps: float = 1e-8): ...

class ClippedAdam:
    """
    Adam optimizer with gradient clipping for improved stability.

    Args:
        step_size: Learning rate
        b1: First moment decay rate
        b2: Second moment decay rate
        eps: Numerical stability constant
        clip_norm: Maximum gradient norm for clipping

    Usage:
        # Useful for training on unstable loss landscapes
        optimizer = ClippedAdam(step_size=0.01, clip_norm=1.0)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.001, b1: float = 0.9,
                 b2: float = 0.999, eps: float = 1e-8, clip_norm: float = 10.0): ...

class Adagrad:
    """
    Adaptive Gradient Algorithm (Adagrad) optimizer.

    Adapts the learning rate per parameter, performing smaller updates for
    parameters associated with frequently occurring features.

    Args:
        step_size: Initial learning rate (default: 0.01)
        eps: Small constant for numerical stability (default: 1e-8)

    Usage:
        # Good for sparse data and features
        optimizer = Adagrad(step_size=0.1)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...

class RMSProp:
    """
    Root Mean Square Propagation (RMSProp) optimizer.

    Maintains a moving average of squared gradients to normalize the gradient.

    Args:
        step_size: Learning rate (default: 0.01)
        decay: Decay rate for the moving average (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-8)

    Usage:
        # Good for non-stationary objectives
        optimizer = RMSProp(step_size=0.01, decay=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, decay: float = 0.9, eps: float = 1e-8): ...

class RMSPropMomentum:
    """
    RMSProp with momentum for improved convergence.

    Args:
        step_size: Learning rate
        decay: Decay rate for the squared-gradient moving average
        momentum: Momentum coefficient
        eps: Numerical stability constant
        centered: Whether to use the centered RMSProp variant

    Usage:
        # Combines the benefits of RMSProp and momentum
        optimizer = RMSPropMomentum(step_size=0.01, momentum=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, decay: float = 0.9,
                 momentum: float = 0.0, eps: float = 1e-8, centered: bool = False): ...
```
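The Adam update described above reduces to a few lines of arithmetic, and ClippedAdam adds global-norm clipping in front of it. A hedged pure-Python sketch on a dict of scalar parameters follows; the helper names (`adam_step`, `clip_by_global_norm`) are illustrative, not NumPyro API.

```python
import math

def clip_by_global_norm(grads, clip_norm):
    # Scale all gradients so their joint L2 norm is at most clip_norm,
    # as done by ClippedAdam before the Adam update.
    norm = math.sqrt(sum(g * g for g in grads.values()))
    scale = min(1.0, clip_norm / max(norm, 1e-12))
    return {k: g * scale for k, g in grads.items()}

def adam_step(params, grads, m, v, t,
              step_size=0.001, b1=0.9, b2=0.999, eps=1e-8):
    # One Adam update: exponential moving averages of the first (m) and
    # second (v) gradient moments, with bias correction for step t >= 1.
    new_m = {k: b1 * m[k] + (1 - b1) * g for k, g in grads.items()}
    new_v = {k: b2 * v[k] + (1 - b2) * g * g for k, g in grads.items()}
    new_params = {}
    for k in params:
        m_hat = new_m[k] / (1 - b1 ** t)
        v_hat = new_v[k] / (1 - b2 ** t)
        new_params[k] = params[k] - step_size * m_hat / (math.sqrt(v_hat) + eps)
    return new_params, new_m, new_v
```

Running `adam_step` repeatedly on the gradient of a simple quadratic drives the parameter toward its minimum, which is the behavior the docstrings above rely on.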
### Momentum-Based Optimizers

Optimizers that use momentum to accelerate convergence.
```python { .api }
class SGD:
    """
    Stochastic Gradient Descent optimizer.

    Basic gradient descent with optional momentum.

    Args:
        step_size: Learning rate (default: 0.01)
        momentum: Momentum coefficient (default: 0.0)

    Usage:
        # Simple gradient descent
        optimizer = SGD(step_size=0.01)

        # With momentum for faster convergence
        optimizer = SGD(step_size=0.01, momentum=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, momentum: float = 0.0): ...

class Momentum:
    """
    Stochastic Gradient Descent with momentum.

    Accelerates gradient descent by accumulating a velocity vector in directions
    of persistent reduction in the objective function.

    Args:
        step_size: Learning rate (default: 0.01)
        mass: Momentum coefficient (default: 0.9)

    Usage:
        # Classical momentum SGD
        optimizer = Momentum(step_size=0.01, mass=0.9)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, mass: float = 0.9): ...
```
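The velocity accumulation that the Momentum docstring describes can be sketched in a few lines of pure Python (the function name `momentum_step` is illustrative, not NumPyro API):

```python
def momentum_step(params, grads, velocity, step_size=0.01, mass=0.9):
    # Classical (heavy-ball) momentum: accumulate a velocity vector from
    # past gradients and step along it.
    new_velocity = {k: mass * velocity[k] + grads[k] for k in grads}
    new_params = {k: params[k] - step_size * new_velocity[k] for k in params}
    return new_params, new_velocity
```

On a quadratic objective the velocity builds up along the consistent downhill direction, which is why momentum typically converges faster than plain SGD.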
### Specialized Optimizers

Advanced optimizers for specific use cases.
```python { .api }
class SM3:
    """
    Square-root of second Moment (SM3) optimizer.

    Memory-efficient adaptive optimizer that maintains a single accumulator
    per parameter instead of separate first and second moment estimates.

    Args:
        step_size: Learning rate (default: 0.01)
        eps: Small constant for numerical stability (default: 1e-8)

    Usage:
        # Memory-efficient alternative to Adam for large models
        optimizer = SM3(step_size=0.01)
        opt_state = optimizer.init(params)
    """
    def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...

class Minimize:
    """
    Wrapper for JAX's minimize function for direct optimization.

    Uses JAX's built-in optimization routines, such as BFGS, for direct
    minimization of objective functions.

    Args:
        method: Optimization method ('BFGS', 'L-BFGS-B', 'CG', etc.)
        options: Additional options for the underlying scipy-style optimizer

    Usage:
        # For objectives where full optimization is preferred over SGD
        optimizer = Minimize(method='L-BFGS-B')

        # Direct minimization (note the different interface)
        result = optimizer.minimize(loss_fn, init_params)
    """
    def __init__(self, method: str = 'BFGS', options: Optional[dict] = None): ...

    def minimize(self, fun: Callable, x0: dict, *args, **kwargs) -> dict:
        """
        Minimize an objective function.

        Args:
            fun: Objective function to minimize
            x0: Initial parameter values
            *args: Additional positional arguments passed to the objective
            **kwargs: Additional keyword arguments

        Returns:
            Optimization result with final parameters and metadata
        """
```
### Optimizer Utilities

Utility functions for working with optimizers and optimization schedules.
```python { .api }
def multi_transform(transforms: dict, param_labels: dict) -> Optimizer:
    """
    Apply different optimizers to different parameter groups.

    Args:
        transforms: Dictionary mapping labels to optimizers
        param_labels: Dictionary mapping parameter names to labels

    Returns:
        Combined optimizer that applies the appropriate transform to each
        parameter group

    Usage:
        # Different learning rates for different parameter groups
        transforms = {
            'weights': Adam(0.01),
            'biases': Adam(0.1)
        }
        param_labels = {
            'layer1.weight': 'weights',
            'layer1.bias': 'biases'
        }
        optimizer = multi_transform(transforms, param_labels)
    """

def exponential_decay(step_size: float, decay_steps: int,
                      decay_rate: float, staircase: bool = False) -> Callable:
    """
    Create an exponential learning rate decay schedule.

    Args:
        step_size: Initial learning rate
        decay_steps: Number of steps after which to apply decay
        decay_rate: Decay factor
        staircase: Whether to apply decay in discrete steps

    Returns:
        Learning rate schedule function

    Usage:
        schedule = exponential_decay(0.1, decay_steps=1000, decay_rate=0.96)
        optimizer = Adam(step_size=schedule)
    """

def polynomial_decay(step_size: float, transition_steps: int,
                     transition_begin: int = 0, power: float = 1.0,
                     end_value: float = 0.0) -> Callable:
    """
    Create a polynomial learning rate decay schedule.

    Args:
        step_size: Initial learning rate
        transition_steps: Number of steps over which to decay
        transition_begin: Step at which to begin decay
        power: Power of the polynomial decay
        end_value: Final learning rate value

    Returns:
        Learning rate schedule function
    """

def warmup_schedule(warmup_steps: int, peak_value: float,
                    end_value: float = 0.0) -> Callable:
    """
    Create a learning rate warmup schedule.

    Args:
        warmup_steps: Number of warmup steps
        peak_value: Peak learning rate after warmup
        end_value: Final learning rate value

    Returns:
        Learning rate schedule function

    Usage:
        # Linear warmup to peak, then decay
        schedule = warmup_schedule(1000, peak_value=0.01)
        optimizer = Adam(step_size=schedule)
    """
```
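A schedule is just a function from step count to learning rate. The following is a hedged pure-Python sketch of what `exponential_decay` computes, following the usual convention `lr(step) = step_size * decay_rate ** (step / decay_steps)`; the exact library behavior may differ in details, and the factory name `make_exponential_decay` is illustrative.

```python
def make_exponential_decay(step_size, decay_steps, decay_rate, staircase=False):
    # lr(step) = step_size * decay_rate ** (step / decay_steps).
    # With staircase=True the exponent is floored, so the rate drops in
    # discrete jumps once every decay_steps steps instead of continuously.
    def schedule(step):
        exponent = step / decay_steps
        if staircase:
            exponent = step // decay_steps
        return step_size * (decay_rate ** exponent)
    return schedule
```

Passing such a function as `step_size` lets the optimizer look up the current learning rate at each step, which is how the `Adam(step_size=schedule)` pattern in the docstrings above works.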
### Integration with SVI

Examples of how optimizers integrate with Stochastic Variational Inference.
```python
# Usage with SVI
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO
from jax import random

def example_svi_usage(data):
    """Example of using optimizers with SVI."""

    # Define model and guide
    def model(data):
        mu = numpyro.sample("mu", dist.Normal(0, 1))
        with numpyro.plate("data", len(data)):
            numpyro.sample("obs", dist.Normal(mu, 1), obs=data)

    def guide(data):
        mu_loc = numpyro.param("mu_loc", 0.0)
        mu_scale = numpyro.param("mu_scale", 1.0, constraint=constraints.positive)
        numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))

    # Various optimizer configurations
    optimizers = {
        # Basic Adam
        'adam': Adam(0.01),

        # Adam with gradient clipping
        'clipped_adam': ClippedAdam(0.01, clip_norm=1.0),

        # RMSProp for non-stationary problems
        'rmsprop': RMSProp(0.01, decay=0.9),

        # SGD with momentum
        'sgd_momentum': SGD(0.01, momentum=0.9),

        # Different rates for different parameters
        'multi_rate': multi_transform({
            'loc': Adam(0.01),
            'scale': Adam(0.001)
        }, {
            'mu_loc': 'loc',
            'mu_scale': 'scale'
        })
    }

    # Run SVI with the chosen optimizer
    optimizer = optimizers['adam']
    svi = SVI(model, guide, optimizer, Trace_ELBO())

    # Training loop
    svi_result = svi.run(random.PRNGKey(0), 1000, data)

    return svi_result
```
## Usage Examples
```python
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam, RMSProp, SGD
import jax.numpy as jnp
from jax import random

# Basic optimizer usage
def simple_optimization_example():
    # Define a simple model
    def model(x, y):
        a = numpyro.sample("a", dist.Normal(0, 1))
        b = numpyro.sample("b", dist.Normal(0, 1))
        mu = a * x + b
        numpyro.sample("y", dist.Normal(mu, 0.1), obs=y)

    def guide(x, y):
        a_loc = numpyro.param("a_loc", 0.0)
        a_scale = numpyro.param("a_scale", 1.0, constraint=constraints.positive)
        b_loc = numpyro.param("b_loc", 0.0)
        b_scale = numpyro.param("b_scale", 1.0, constraint=constraints.positive)

        numpyro.sample("a", dist.Normal(a_loc, a_scale))
        numpyro.sample("b", dist.Normal(b_loc, b_scale))

    # Generate synthetic data
    true_a, true_b = 2.0, 1.0
    x = jnp.linspace(0, 1, 100)
    y = true_a * x + true_b + 0.1 * random.normal(random.PRNGKey(0), (100,))

    # Compare different optimizers
    optimizers = {
        'Adam': Adam(0.01),
        'RMSProp': RMSProp(0.01),
        'SGD': SGD(0.01, momentum=0.9)
    }

    results = {}
    for name, optimizer in optimizers.items():
        svi = SVI(model, guide, optimizer, Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(1), 1000, x, y)
        results[name] = svi_result

        # Print the final loss
        print(f"{name} final loss: {svi_result.losses[-1]:.4f}")

    return results

# Advanced optimizer configuration
def advanced_optimization_example():
    # Complex model with multiple parameter groups
    def hierarchical_model(group_idx, y):
        # Global parameters
        mu_global = numpyro.sample("mu_global", dist.Normal(0, 10))
        sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))

        # Group parameters
        n_groups = len(jnp.unique(group_idx))
        with numpyro.plate("groups", n_groups):
            mu_group = numpyro.sample("mu_group", dist.Normal(mu_global, sigma_global))

        # Observations
        with numpyro.plate("data", len(y)):
            mu = mu_group[group_idx]
            numpyro.sample("y", dist.Normal(mu, 1), obs=y)

    def hierarchical_guide(group_idx, y):
        # Variational families for global parameters
        mu_global_loc = numpyro.param("mu_global_loc", 0.0)
        mu_global_scale = numpyro.param("mu_global_scale", 1.0, constraint=constraints.positive)
        sigma_global_rate = numpyro.param("sigma_global_rate", 1.0, constraint=constraints.positive)

        # Variational families for group parameters
        n_groups = len(jnp.unique(group_idx))
        mu_group_loc = numpyro.param("mu_group_loc", jnp.zeros(n_groups))
        mu_group_scale = numpyro.param("mu_group_scale", jnp.ones(n_groups), constraint=constraints.positive)

        # Sample from the variational distributions
        numpyro.sample("mu_global", dist.Normal(mu_global_loc, mu_global_scale))
        numpyro.sample("sigma_global", dist.Exponential(sigma_global_rate))

        with numpyro.plate("groups", n_groups):
            numpyro.sample("mu_group", dist.Normal(mu_group_loc, mu_group_scale))

    # Multi-rate optimization: different learning rates for global vs. group
    # parameters (multi_transform as described in the utilities above)
    optimizer = multi_transform({
        'global': Adam(0.01),  # Slower for global parameters
        'group': Adam(0.05)    # Faster for group parameters
    }, {
        'mu_global_loc': 'global',
        'mu_global_scale': 'global',
        'sigma_global_rate': 'global',
        'mu_group_loc': 'group',
        'mu_group_scale': 'group'
    })

    # Learning rate schedule (exponential_decay as described above)
    schedule = exponential_decay(step_size=0.01, decay_steps=500, decay_rate=0.96)
    scheduled_optimizer = Adam(step_size=schedule)

    return optimizer, scheduled_optimizer
```
## Types
```python { .api }
from typing import Optional, Union, Callable, Dict, Any, Tuple
from jax import Array
import jax.numpy as jnp

ArrayLike = Union[Array, jnp.ndarray, float, int]
Params = Dict[str, ArrayLike]
Grads = Dict[str, ArrayLike]
Updates = Dict[str, ArrayLike]
OptState = Any  # Optimizer-specific state type

class OptimizerState:
    """Base optimizer state interface."""
    step: int
    params: Params

class AdamState(OptimizerState):
    """State for the Adam optimizer."""
    step: int
    params: Params
    m: Params  # First moment estimates
    v: Params  # Second moment estimates

class SGDState(OptimizerState):
    """State for the SGD optimizer."""
    step: int
    params: Params
    momentum: Optional[Params]  # Momentum terms

class RMSPropState(OptimizerState):
    """State for the RMSProp optimizer."""
    step: int
    params: Params
    v: Params  # Moving average of squared gradients

# Optimizer interface
class OptimizerProtocol:
    """Protocol for NumPyro optimizers."""
    def init(self, params: Params) -> OptState: ...
    def update(self, grads: Grads, state: OptState, params: Params) -> Tuple[Updates, OptState]: ...
    def get_params(self, state: OptState) -> Params: ...

# Schedule functions
ScheduleFunction = Callable[[int], float]

# Optimizer factory functions
OptimizerFactory = Callable[..., OptimizerProtocol]
```