# Optimization

Pyro optimization utilities for training probabilistic models, including custom optimizers and PyTorch optimizer wrappers for use with Pyro's parameter store.

## Capabilities

### Core Optimization Classes

Base classes and utilities for wrapping PyTorch optimizers to work with Pyro's parameter management system.

```python { .api }
class PyroOptim:
    """
    Base wrapper class for PyTorch optimizers that works with Pyro's parameter store.

    Automatically manages parameter registration and optimization for Pyro models.
    """
    def __init__(self, optim_constructor, optim_args, clip_args=None):
        """
        Parameters:
        - optim_constructor: PyTorch optimizer constructor
        - optim_args (dict or callable): Arguments to pass to the optimizer,
          or a callable returning such a dict for a given parameter
        - clip_args (dict, optional): Gradient clipping arguments
        """

    def __call__(self, params, *args, **kwargs):
        """Take an optimization step for each parameter, creating a per-parameter optimizer on first use."""

class PyroLRScheduler:
    """
    Wrapper for PyTorch learning rate schedulers that works with PyroOptim.
    """
    def __init__(self, scheduler_constructor, optim_args, clip_args=None):
        """
        Parameters:
        - scheduler_constructor: PyTorch scheduler constructor
        - optim_args (dict): Must contain the key 'optimizer' (a PyTorch
          optimizer class) and the key 'optim_args' (arguments for that
          optimizer), plus any scheduler-specific arguments
        - clip_args (dict, optional): Gradient clipping arguments
        """
```

### Pyro-Specific Optimizers

Custom optimizers designed specifically for probabilistic programming use cases.

```python { .api }
class ClippedAdam(PyroOptim):
    """
    Adam optimizer with gradient clipping and learning-rate decay support.

    Particularly useful for training probabilistic models where gradients
    can become unstable.
    """
    def __init__(self, optim_args):
        """
        Parameters:
        - optim_args (dict): Arguments for the optimizer (lr, betas, eps,
          weight_decay, clip_norm, lrd)
        """

class AdagradRMSProp(PyroOptim):
    """
    Hybrid optimizer combining the advantages of Adagrad and RMSProp.

    Designed for sparse gradient scenarios common in probabilistic models.
    """
    def __init__(self, optim_args):
        """
        Parameters:
        - optim_args (dict): Optimizer arguments (eta, delta, t)
        """

class DCTAdam(PyroOptim):
    """
    Adam optimizer with discrete cosine transform (DCT) preconditioning.

    Useful for models with structured parameter spaces.
    """
    def __init__(self, optim_args):
        """
        Parameters:
        - optim_args (dict): Optimizer arguments and DCT configuration
        """
```

### PyTorch Optimizer Wrappers

All standard PyTorch optimizers are available as Pyro-wrapped versions for seamless integration with the parameter store.

```python { .api }
# Standard PyTorch optimizers wrapped for Pyro
def Adam(optim_args, clip_args=None):
    """Wrapped torch.optim.Adam for Pyro's parameter store."""

def SGD(optim_args, clip_args=None):
    """Wrapped torch.optim.SGD for Pyro's parameter store."""

def RMSprop(optim_args, clip_args=None):
    """Wrapped torch.optim.RMSprop for Pyro's parameter store."""

def Adagrad(optim_args, clip_args=None):
    """Wrapped torch.optim.Adagrad for Pyro's parameter store."""

def AdamW(optim_args, clip_args=None):
    """Wrapped torch.optim.AdamW for Pyro's parameter store."""

# And many more PyTorch optimizers...
```

### Distributed Training

Support for distributed optimization across multiple processes or machines.

```python { .api }
class HorovodOptimizer(PyroOptim):
    """
    Wrapper for Horovod distributed training integration.

    Enables data-parallel training of Pyro models across multiple GPUs/nodes.
    """
    def __init__(self, pyro_optim, **horovod_kwargs):
        """
        Parameters:
        - pyro_optim (PyroOptim): The Pyro optimizer to distribute
        - **horovod_kwargs: Extra arguments forwarded to
          horovod.torch.DistributedOptimizer
        """
```

## Examples

### Basic SVI Setup

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(100) * 2.0 + 3.0  # toy observations

def model(data):
    mu = pyro.param("mu", torch.tensor(0.0))
    sigma = pyro.param("sigma", torch.tensor(1.0), constraint=dist.constraints.positive)

    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

def guide(data):
    pass  # empty guide: maximum likelihood estimation of mu and sigma

# Set up the optimizer
adam = Adam({"lr": 0.01})
svi = SVI(model, guide, adam, Trace_ELBO())

# Training loop
for step in range(1000):
    loss = svi.step(data)
```

### Custom Learning Rate and Clipping

```python
from pyro.optim import ClippedAdam

# Adam with per-step gradient clipping; clip_norm is passed as part of
# optim_args alongside the usual Adam arguments.
clipped_adam = ClippedAdam({"lr": 0.01, "betas": (0.9, 0.999), "clip_norm": 10.0})

svi = SVI(model, guide, clipped_adam, Trace_ELBO())
```

### Learning Rate Scheduling

```python
import torch
from pyro.optim import PyroLRScheduler

# The scheduler wraps both the optimizer class and its arguments:
# 'optimizer' names a torch optimizer class, 'optim_args' holds that
# optimizer's arguments, and the remaining keys configure the scheduler.
scheduler = PyroLRScheduler(
    torch.optim.lr_scheduler.StepLR,
    {"optimizer": torch.optim.Adam, "optim_args": {"lr": 0.1},
     "step_size": 100, "gamma": 0.1},
)

# Use in SVI
svi = SVI(model, guide, scheduler, Trace_ELBO())

# Training with a scheduled learning rate
for step in range(1000):
    loss = svi.step(data)
    if step % 100 == 0:
        scheduler.step()  # advance the learning-rate schedule
```

### Multi-Optimizer Setup

```python
import pyro
from pyro import poutine
from pyro.optim import Adam, SGD
from pyro.optim.multi import MixedMultiOptimizer

# Assign different optimizers to different parameter groups. Unlike a
# PyroOptim, a MultiOptimizer is stepped manually with a loss and an
# explicit dict of parameters rather than being passed to SVI.
multi_optim = MixedMultiOptimizer([
    (["mu"], Adam({"lr": 0.01})),
    (["sigma"], SGD({"lr": 0.001})),
])

with poutine.trace(param_only=True) as param_capture:
    loss = -poutine.trace(model).get_trace(data).log_prob_sum()
params = {name: pyro.param(name).unconstrained()
          for name in param_capture.trace.nodes.keys()}
multi_optim.step(loss, params)
```