or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

diagnostics.md, distributions.md, handlers.md, index.md, inference.md, optimization.md, primitives.md, utilities.md

docs/optimization.md

0

# Optimization

1

2

NumPyro provides a collection of gradient-based optimizers for parameter learning in variational inference and maximum likelihood estimation. All optimizers are built on JAX for efficient automatic differentiation and support JIT compilation for high-performance optimization.

3

4

## Capabilities

5

6

### Core Optimizer Infrastructure

7

8

Base classes and utilities for the optimization system.

9

10

```python { .api }

11

class Optimizer:

12

"""

13

Base class for optimizers in NumPyro.

14

15

All optimizers follow the same interface pattern for consistency

16

with JAX optimization libraries like optax.

17

"""

18

def init(self, params: dict) -> Any:

19

"""

20

Initialize optimizer state.

21

22

Args:

23

params: Initial parameter values

24

25

Returns:

26

Initial optimizer state

27

"""

28

29

def update(self, grads: dict, state: Any, params: dict) -> tuple:

30

"""

31

Update parameters based on gradients.

32

33

Args:

34

grads: Parameter gradients

35

state: Current optimizer state

36

params: Current parameter values

37

38

Returns:

39

Tuple of (updates, new_state)

40

"""

41

42

def get_params(self, state: Any) -> dict:

43

"""Get current parameter values from optimizer state."""

44

```

45

46

### Adaptive Learning Rate Optimizers

47

48

Optimizers that adapt learning rates based on gradient history.

49

50

```python { .api }

51

class Adam:

52

"""

53

Adaptive Moment Estimation (Adam) optimizer.

54

55

Computes individual adaptive learning rates for different parameters from

56

estimates of first and second moments of the gradients.

57

58

Args:

59

step_size: Learning rate (default: 0.001)

60

b1: Exponential decay rate for first moment estimates (default: 0.9)

61

b2: Exponential decay rate for second moment estimates (default: 0.999)

62

eps: Small constant for numerical stability (default: 1e-8)

63

64

Usage:

65

optimizer = Adam(step_size=0.01)

66

opt_state = optimizer.init(params)

67

68

for step in range(num_steps):

69

grads = compute_gradients(params)

70

updates, opt_state = optimizer.update(grads, opt_state, params)

71

params = apply_updates(params, updates)

72

"""

73

def __init__(self, step_size: float = 0.001, b1: float = 0.9,

74

b2: float = 0.999, eps: float = 1e-8): ...

75

76

class ClippedAdam:

77

"""

78

Adam optimizer with gradient clipping for improved stability.

79

80

Args:

81

step_size: Learning rate

82

b1: First moment decay rate

83

b2: Second moment decay rate

84

eps: Numerical stability constant

85

clip_norm: Maximum gradient norm for clipping

86

87

Usage:

88

# Useful for training on unstable loss landscapes

89

optimizer = ClippedAdam(step_size=0.01, clip_norm=1.0)

90

opt_state = optimizer.init(params)

91

"""

92

def __init__(self, step_size: float = 0.001, b1: float = 0.9,

93

b2: float = 0.999, eps: float = 1e-8, clip_norm: float = 10.0): ...

94

95

class Adagrad:

96

"""

97

Adaptive Gradient Algorithm (Adagrad) optimizer.

98

99

Adapts learning rate to parameters, performing smaller updates for parameters

100

associated with frequently occurring features.

101

102

Args:

103

step_size: Initial learning rate (default: 0.01)

104

eps: Small constant for numerical stability (default: 1e-8)

105

106

Usage:

107

# Good for sparse data and features

108

optimizer = Adagrad(step_size=0.1)

109

opt_state = optimizer.init(params)

110

"""

111

def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...

112

113

class RMSProp:

114

"""

115

Root Mean Square Propagation (RMSProp) optimizer.

116

117

Maintains a moving average of squared gradients to normalize the gradient.

118

119

Args:

120

step_size: Learning rate (default: 0.01)

121

decay: Decay rate for moving average (default: 0.9)

122

eps: Small constant for numerical stability (default: 1e-8)

123

124

Usage:

125

# Good for non-stationary objectives

126

optimizer = RMSProp(step_size=0.01, decay=0.9)

127

opt_state = optimizer.init(params)

128

"""

129

def __init__(self, step_size: float = 0.01, decay: float = 0.9, eps: float = 1e-8): ...

130

131

class RMSPropMomentum:

132

"""

133

RMSProp with momentum for improved convergence.

134

135

Args:

136

step_size: Learning rate

137

decay: Decay rate for squared gradient moving average

138

momentum: Momentum coefficient

139

eps: Numerical stability constant

140

centered: Whether to use centered RMSProp variant

141

142

Usage:

143

# Combines benefits of RMSProp and momentum

144

optimizer = RMSPropMomentum(step_size=0.01, momentum=0.9)

145

opt_state = optimizer.init(params)

146

"""

147

def __init__(self, step_size: float = 0.01, decay: float = 0.9,

148

momentum: float = 0.0, eps: float = 1e-8, centered: bool = False): ...

149

```

150

151

### Momentum-Based Optimizers

152

153

Optimizers that use momentum to accelerate convergence.

154

155

```python { .api }

156

class SGD:

157

"""

158

Stochastic Gradient Descent optimizer.

159

160

Basic gradient descent with optional momentum.

161

162

Args:

163

step_size: Learning rate (default: 0.01)

164

momentum: Momentum coefficient (default: 0.0)

165

166

Usage:

167

# Simple gradient descent

168

optimizer = SGD(step_size=0.01)

169

170

# With momentum for faster convergence

171

optimizer = SGD(step_size=0.01, momentum=0.9)

172

opt_state = optimizer.init(params)

173

"""

174

def __init__(self, step_size: float = 0.01, momentum: float = 0.0): ...

175

176

class Momentum:

177

"""

178

Stochastic Gradient Descent with momentum.

179

180

Accelerates gradient descent by accumulating a velocity vector in directions

181

of persistent reduction in the objective function.

182

183

Args:

184

step_size: Learning rate (default: 0.01)

185

mass: Momentum coefficient (default: 0.9)

186

187

Usage:

188

# Classical momentum SGD

189

optimizer = Momentum(step_size=0.01, mass=0.9)

190

opt_state = optimizer.init(params)

191

"""

192

def __init__(self, step_size: float = 0.01, mass: float = 0.9): ...

193

```

194

195

### Specialized Optimizers

196

197

Advanced optimizers for specific use cases.

198

199

```python { .api }

200

class SM3:

201

"""

202

SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) optimizer.

203

204

Memory-efficient adaptive optimizer that maintains a single accumulator

205

per parameter instead of separate first and second moment estimates.

206

207

Args:

208

step_size: Learning rate (default: 0.01)

209

eps: Small constant for numerical stability (default: 1e-8)

210

211

Usage:

212

# Memory-efficient alternative to Adam for large models

213

optimizer = SM3(step_size=0.01)

214

opt_state = optimizer.init(params)

215

"""

216

def __init__(self, step_size: float = 0.01, eps: float = 1e-8): ...

217

218

class Minimize:

219

"""

220

Wrapper for JAX's minimize function for direct optimization.

221

222

Uses JAX's built-in optimization routines like L-BFGS for direct

223

minimization of objective functions.

224

225

Args:

226

method: Optimization method ('BFGS', 'L-BFGS-B', 'CG', etc.)

227

options: Additional options for the underlying scipy optimizer

228

229

Usage:

230

# For objectives where full optimization is preferred over SGD

231

optimizer = Minimize(method='L-BFGS-B')

232

233

# Direct minimization (different interface)

234

result = optimizer.minimize(loss_fn, init_params)

235

"""

236

def __init__(self, method: str = 'BFGS', options: Optional[dict] = None): ...

237

238

def minimize(self, fun: Callable, x0: dict, *args, **kwargs) -> dict:

239

"""

240

Minimize objective function.

241

242

Args:

243

fun: Objective function to minimize

244

x0: Initial parameter values

245

*args: Additional arguments to objective function

246

**kwargs: Additional keyword arguments

247

248

Returns:

249

Optimization result with final parameters and metadata

250

"""

251

```

252

253

### Optimizer Utilities

254

255

Utility functions for working with optimizers and optimization schedules.

256

257

```python { .api }

258

def multi_transform(transforms: dict, param_labels: dict) -> Optimizer:

259

"""

260

Apply different optimizers to different parameter groups.

261

262

Args:

263

transforms: Dictionary mapping labels to optimizers

264

param_labels: Dictionary mapping parameter names to labels

265

266

Returns:

267

Combined optimizer that applies appropriate transform to each parameter group

268

269

Usage:

270

# Different learning rates for different parameter groups

271

transforms = {

272

'weights': Adam(0.01),

273

'biases': Adam(0.1)

274

}

275

param_labels = {

276

'layer1.weight': 'weights',

277

'layer1.bias': 'biases'

278

}

279

optimizer = multi_transform(transforms, param_labels)

280

"""

281

282

def exponential_decay(step_size: float, decay_steps: int,

283

decay_rate: float, staircase: bool = False) -> Callable:

284

"""

285

Create exponential learning rate decay schedule.

286

287

Args:

288

step_size: Initial learning rate

289

decay_steps: Number of steps after which to apply decay

290

decay_rate: Decay factor

291

staircase: Whether to apply decay in discrete steps

292

293

Returns:

294

Learning rate schedule function

295

296

Usage:

297

schedule = exponential_decay(0.1, decay_steps=1000, decay_rate=0.96)

298

optimizer = Adam(step_size=schedule)

299

"""

300

301

def polynomial_decay(step_size: float, transition_steps: int,

302

transition_begin: int = 0, power: float = 1.0,

303

end_value: float = 0.0) -> Callable:

304

"""

305

Create polynomial learning rate decay schedule.

306

307

Args:

308

step_size: Initial learning rate

309

transition_steps: Number of steps over which to decay

310

transition_begin: Step at which to begin decay

311

power: Power of polynomial decay

312

end_value: Final learning rate value

313

314

Returns:

315

Learning rate schedule function

316

"""

317

318

def warmup_schedule(warmup_steps: int, peak_value: float,

319

end_value: float = 0.0) -> Callable:

320

"""

321

Create learning rate warmup schedule.

322

323

Args:

324

warmup_steps: Number of warmup steps

325

peak_value: Peak learning rate after warmup

326

end_value: Final learning rate value

327

328

Returns:

329

Learning rate schedule function

330

331

Usage:

332

# Linear warmup to peak, then decay

333

schedule = warmup_schedule(1000, peak_value=0.01)

334

optimizer = Adam(step_size=schedule)

335

"""

336

```

337

338

### Integration with SVI

339

340

Examples of how optimizers integrate with Stochastic Variational Inference.

341

342

```python { .api }

343

# Usage with SVI

344

from numpyro.infer import SVI, Trace_ELBO

345

346

def example_svi_usage():

347

"""Example of using optimizers with SVI."""

348

349

# Define model and guide

350

def model(data):

351

mu = numpyro.sample("mu", dist.Normal(0, 1))

352

with numpyro.plate("data", len(data)):

353

numpyro.sample("obs", dist.Normal(mu, 1), obs=data)

354

355

def guide(data):

356

mu_loc = numpyro.param("mu_loc", 0.0)

357

mu_scale = numpyro.param("mu_scale", 1.0, constraint=constraints.positive)

358

numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))

359

360

# Various optimizer configurations

361

optimizers = {

362

# Basic Adam

363

'adam': Adam(0.01),

364

365

# Adam with gradient clipping

366

'clipped_adam': ClippedAdam(0.01, clip_norm=1.0),

367

368

# RMSProp for non-stationary problems

369

'rmsprop': RMSProp(0.01, decay=0.9),

370

371

# SGD with momentum

372

'sgd_momentum': SGD(0.01, momentum=0.9),

373

374

# Different rates for different parameters

375

'multi_rate': multi_transform({

376

'loc': Adam(0.01),

377

'scale': Adam(0.001)

378

}, {

379

'mu_loc': 'loc',

380

'mu_scale': 'scale'

381

})

382

}

383

384

# Run SVI with chosen optimizer

385

optimizer = optimizers['adam']

386

svi = SVI(model, guide, optimizer, Trace_ELBO())

387

388

# Training loop

389

svi_result = svi.run(random.PRNGKey(0), 1000, data)

390

391

return svi_result

392

```

393

394

## Usage Examples

395

396

```python

397

import numpyro

398

import numpyro.distributions as dist

399

from numpyro.infer import SVI, Trace_ELBO

400

from numpyro.optim import Adam, RMSProp, SGD

401

import jax.numpy as jnp

402

from jax import random

403

404

# Basic optimizer usage

405

def simple_optimization_example():

406

# Define simple model

407

def model(x, y):

408

a = numpyro.sample("a", dist.Normal(0, 1))

409

b = numpyro.sample("b", dist.Normal(0, 1))

410

mu = a * x + b

411

numpyro.sample("y", dist.Normal(mu, 0.1), obs=y)

412

413

def guide(x, y):

414

a_loc = numpyro.param("a_loc", 0.0)

415

a_scale = numpyro.param("a_scale", 1.0, constraint=constraints.positive)

416

b_loc = numpyro.param("b_loc", 0.0)

417

b_scale = numpyro.param("b_scale", 1.0, constraint=constraints.positive)

418

419

numpyro.sample("a", dist.Normal(a_loc, a_scale))

420

numpyro.sample("b", dist.Normal(b_loc, b_scale))

421

422

# Generate synthetic data

423

true_a, true_b = 2.0, 1.0

424

x = jnp.linspace(0, 1, 100)

425

y = true_a * x + true_b + 0.1 * random.normal(random.PRNGKey(0), (100,))

426

427

# Compare different optimizers

428

optimizers = {

429

'Adam': Adam(0.01),

430

'RMSProp': RMSProp(0.01),

431

'SGD': SGD(0.01, momentum=0.9)

432

}

433

434

results = {}

435

for name, optimizer in optimizers.items():

436

svi = SVI(model, guide, optimizer, Trace_ELBO())

437

svi_result = svi.run(random.PRNGKey(1), 1000, x, y)

438

results[name] = svi_result

439

440

# Print final loss

441

print(f"{name} final loss: {svi_result.losses[-1]:.4f}")

442

443

return results

444

445

# Advanced optimizer configuration

446

def advanced_optimization_example():

447

# Complex model with multiple parameter groups

448

def hierarchical_model(group_idx, y):

449

# Global parameters

450

mu_global = numpyro.sample("mu_global", dist.Normal(0, 10))

451

sigma_global = numpyro.sample("sigma_global", dist.Exponential(1))

452

453

# Group parameters

454

n_groups = len(jnp.unique(group_idx))

455

with numpyro.plate("groups", n_groups):

456

mu_group = numpyro.sample("mu_group", dist.Normal(mu_global, sigma_global))

457

458

# Observations

459

with numpyro.plate("data", len(y)):

460

mu = mu_group[group_idx]

461

numpyro.sample("y", dist.Normal(mu, 1), obs=y)

462

463

def hierarchical_guide(group_idx, y):

464

# Global parameter variational families

465

mu_global_loc = numpyro.param("mu_global_loc", 0.0)

466

mu_global_scale = numpyro.param("mu_global_scale", 1.0, constraint=constraints.positive)

467

sigma_global_rate = numpyro.param("sigma_global_rate", 1.0, constraint=constraints.positive)

468

469

# Group parameter variational families

470

n_groups = len(jnp.unique(group_idx))

471

mu_group_loc = numpyro.param("mu_group_loc", jnp.zeros(n_groups))

472

mu_group_scale = numpyro.param("mu_group_scale", jnp.ones(n_groups), constraint=constraints.positive)

473

474

# Sample from variational distributions

475

numpyro.sample("mu_global", dist.Normal(mu_global_loc, mu_global_scale))

476

numpyro.sample("sigma_global", dist.Exponential(sigma_global_rate))

477

478

with numpyro.plate("groups", n_groups):

479

numpyro.sample("mu_group", dist.Normal(mu_group_loc, mu_group_scale))

480

481

# Multi-rate optimization: different learning rates for global vs group parameters

482

optimizer = multi_transform({

483

'global': Adam(0.01), # Slower for global parameters

484

'group': Adam(0.05) # Faster for group parameters

485

}, {

486

'mu_global_loc': 'global',

487

'mu_global_scale': 'global',

488

'sigma_global_rate': 'global',

489

'mu_group_loc': 'group',

490

'mu_group_scale': 'group'

491

})

492

493

# Learning rate schedule

494

schedule = exponential_decay(step_size=0.01, decay_steps=500, decay_rate=0.96)

495

scheduled_optimizer = Adam(step_size=schedule)

496

497

return optimizer, scheduled_optimizer

498

```

499

500

## Types

501

502

```python { .api }

503

from typing import Optional, Union, Callable, Dict, Any, Tuple

504

from jax import Array

505

import jax.numpy as jnp

506

507

ArrayLike = Union[Array, jnp.ndarray, float, int]

508

Params = Dict[str, ArrayLike]

509

Grads = Dict[str, ArrayLike]

510

Updates = Dict[str, ArrayLike]

511

OptState = Any # Optimizer-specific state type

512

513

class OptimizerState:

514

"""Base optimizer state interface."""

515

step: int

516

params: Params

517

518

class AdamState(OptimizerState):

519

"""State for Adam optimizer."""

520

step: int

521

params: Params

522

m: Params # First moment estimates

523

v: Params # Second moment estimates

524

525

class SGDState(OptimizerState):

526

"""State for SGD optimizer."""

527

step: int

528

params: Params

529

momentum: Optional[Params] # Momentum terms

530

531

class RMSPropState(OptimizerState):

532

"""State for RMSProp optimizer."""

533

step: int

534

params: Params

535

v: Params # Squared gradient moving average

536

537

# Optimizer interface

538

class OptimizerProtocol:

539

"""Protocol for NumPyro optimizers."""

540

def init(self, params: Params) -> OptState: ...

541

def update(self, grads: Grads, state: OptState, params: Params) -> Tuple[Updates, OptState]: ...

542

def get_params(self, state: OptState) -> Params: ...

543

544

# Schedule functions

545

ScheduleFunction = Callable[[int], float]

546

547

# Optimizer factory functions

548

OptimizerFactory = Callable[..., OptimizerProtocol]

549

```