# Monte Carlo Gradient Estimation
Utilities for efficient Monte Carlo gradient estimation. This module provides techniques for approximating gradients of expectations, including score function, pathwise, and measure-valued estimators, plus control variates for variance reduction.
**Note:** All functions in this module are deprecated and will be removed in Optax version 0.3.0.
## Capabilities
### Score Function Gradient Estimation
#### REINFORCE Estimator
Estimates gradients using the score function method (REINFORCE). Approximates ∇_θ E_{p(x;θ)}[f(x)] using E_{p(x;θ)}[f(x) ∇_θ log p(x;θ)].
```python { .api }
def score_function_jacobians(
    function,
    params,
    dist_builder,
    rng,
    num_samples
):
    """
    Score function gradient estimation (REINFORCE).

    Args:
        function: Function f(x) for gradient estimation
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation

    Returns:
        Sequence[chex.Array]: Tuple of jacobian vectors with shape num_samples x param.shape
    """
```
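
The identity can be written out directly. Below is a minimal sketch of the score function estimator for a diagonal Gaussian, using TensorFlow Probability's JAX substrate; it illustrates the surrogate-loss trick, not the module's actual implementation:

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

def reinforce_sketch(f, mean, log_std, rng, num_samples):
    # Draw x ~ p(x; theta); the samples are treated as constants below.
    dist = tfd.Normal(loc=mean, scale=jnp.exp(log_std))
    x = dist.sample(num_samples, seed=rng)

    # Surrogate whose gradient is f(x) * grad_theta log p(x; theta).
    def surrogate(mean, log_std):
        d = tfd.Normal(loc=mean, scale=jnp.exp(log_std))
        return jax.vmap(lambda xi: f(xi) * jnp.sum(d.log_prob(xi)))(x)

    # One jacobian per parameter, each of shape [num_samples, *param.shape].
    return jax.jacobian(surrogate, argnums=(0, 1))(mean, log_std)
```

Averaging each jacobian over its first axis gives the gradient estimate.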
### Pathwise Gradient Estimation
#### Reparameterization Trick
Estimates gradients using the pathwise method (reparameterization trick). Approximates ∇_θ E_{p(x;θ)}[f(x)] using E_{p(ε)}[∇_θ f(g(ε,θ))], where x = g(ε,θ).
```python { .api }
def pathwise_jacobians(
    function,
    params,
    dist_builder,
    rng,
    num_samples
):
    """
    Pathwise gradient estimation (reparameterization trick).

    Args:
        function: Function f(x) for gradient estimation (must be differentiable)
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation

    Returns:
        Sequence[chex.Array]: Tuple of jacobian vectors with shape num_samples x param.shape
    """
```
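
For a Gaussian, the reparameterization x = g(ε,θ) = μ + σε with ε ~ N(0, I) makes each sample a differentiable function of the parameters, so gradients can flow through f directly. A minimal sketch (not the module's implementation):

```python
import jax
import jax.numpy as jnp

def pathwise_sketch(f, mean, log_std, rng, num_samples):
    # Fixed noise eps ~ N(0, I), independent of the parameters.
    eps = jax.random.normal(rng, (num_samples,) + mean.shape)

    # x = mean + std * eps is differentiable in the parameters, so
    # gradients flow through f(g(eps, theta)) directly.
    def per_sample(mean, log_std):
        x = mean + jnp.exp(log_std) * eps
        return jax.vmap(f)(x)

    # One jacobian per parameter, each of shape [num_samples, *param.shape].
    return jax.jacobian(per_sample, argnums=(0, 1))(mean, log_std)
```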
### Measure-Valued Gradient Estimation
#### Measure Difference Method
Estimates gradients using differences between related measures. Currently only supports Gaussian random variables.
```python { .api }
def measure_valued_jacobians(
    function,
    params,
    dist_builder,
    rng,
    num_samples,
    coupling=True
):
    """
    Measure-valued gradient estimation.

    Args:
        function: Function f(x) for gradient estimation
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
        coupling: Whether to use coupling for positive/negative samples (default: True)

    Returns:
        Sequence[chex.Array]: Tuple of jacobian vectors with shape num_samples x param.shape
    """
```
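
For a scalar Gaussian, the derivative of the density with respect to the mean decomposes into a difference of two Weibull-shifted measures (Mohamed et al., 2020, "Monte Carlo Gradient Estimation in Machine Learning"). A hedged sketch of the mean gradient with coupled samples, assuming `jax.random.weibull_min` for the Weibull draws:

```python
import jax
import jax.numpy as jnp

def mv_mean_grad_sketch(f, mean, std, rng, num_samples):
    # d/dmean N(x; mean, std**2) splits into positive and negative parts,
    # each a Weibull(sqrt(2), 2) shift of the mean, scaled by a constant.
    w = jax.random.weibull_min(
        rng, scale=jnp.sqrt(2.0), concentration=2.0, shape=(num_samples,))
    f_pos = jax.vmap(f)(mean + std * w)  # samples from the positive measure
    f_neg = jax.vmap(f)(mean - std * w)  # coupled: same w in both measures
    # Per-sample gradient contributions; average over axis 0 to estimate.
    return (f_pos - f_neg) / (std * jnp.sqrt(2.0 * jnp.pi))
```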
### Control Variates
#### Moving Average Baseline
Implements a moving average baseline for variance reduction.
```python { .api }
def moving_avg_baseline(
    function,
    decay=0.99,
    zero_debias=True,
    use_decay_early_training_heuristic=True
):
    """
    Moving average baseline control variate.

    Args:
        function: Function for which to compute the control variate
        decay: Decay rate for the moving average (default: 0.99)
        zero_debias: Whether to use zero debiasing (default: True)
        use_decay_early_training_heuristic: Whether to use early training heuristic (default: True)

    Returns:
        ControlVariate: Tuple of three functions for computing control variate
    """
```
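
The baseline tracks an exponential moving average of f and subtracts it from f(x) inside the score function term; zero-debiasing divides by 1 − decay^t so early estimates are not biased toward the zero initialization, and the early-training heuristic additionally shrinks the effective decay during the first iterations (omitted here). A minimal sketch of one state update (a hypothetical helper, not the module's API):

```python
def update_baseline(state, f_value, decay=0.99, zero_debias=True):
    """One moving-average update; state = (ema, step_count)."""
    ema, t = state
    ema = decay * ema + (1.0 - decay) * f_value
    t = t + 1
    # Zero-debiasing corrects the bias from initializing the EMA at zero.
    baseline = ema / (1.0 - decay**t) if zero_debias else ema
    return (ema, t), baseline
```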
#### Control Delta Method
Implements the delta method control variate, based on a second-order Taylor expansion of the function around the mean of the distribution.
```python { .api }
def control_delta_method(function):
    """
    Delta method control variate.

    Args:
        function: The function for which to compute the control variate

    Returns:
        ControlVariate: Tuple of three functions for computing control variate
    """
```
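
The control variate here is the second-order Taylor expansion of f around the distribution's mean; under a diagonal Gaussian its expectation has a closed form, since the linear term vanishes and the quadratic term reduces to the Hessian diagonal. A minimal sketch (hypothetical helpers, not the module's internals):

```python
import jax
import jax.numpy as jnp

def taylor_cv(f, x, mean):
    # Second-order Taylor expansion of f around the mean, evaluated at x.
    g = jax.grad(f)(mean)
    h = jax.hessian(f)(mean)
    d = x - mean
    return f(mean) + g @ d + 0.5 * d @ h @ d

def expected_taylor_cv(f, mean, std):
    # E[cv(x)] for x ~ N(mean, diag(std**2)): E[d] = 0 kills the linear
    # term, and E[d d^T] = diag(std**2) leaves only the Hessian diagonal.
    h = jax.hessian(f)(mean)
    return f(mean) + 0.5 * jnp.sum(std**2 * jnp.diag(h))
```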
#### Control Variates with Jacobians
Combines control variates with gradient estimators for variance reduction.
```python { .api }
def control_variates_jacobians(
    function,
    control_variate_from_function,
    grad_estimator,
    params,
    dist_builder,
    rng,
    num_samples,
    control_variate_state=None,
    estimate_cv_coeffs=False,
    estimate_cv_coeffs_num_samples=20
):
    """
    Gradient estimation using control variates for variance reduction.

    Args:
        function: Function f(x) for which to estimate gradients
        control_variate_from_function: The control variate to use
        grad_estimator: The gradient estimator to compute gradients
        params: Parameters for constructing the distribution
        dist_builder: Constructor that builds a distribution from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
        control_variate_state: State of the control variate (optional)
        estimate_cv_coeffs: Whether to estimate optimal coefficients
        estimate_cv_coeffs_num_samples: Number of samples for coefficient estimation

    Returns:
        tuple[Sequence[chex.Array], CvState]: Jacobians and updated control variate state
    """
```
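
A hedged usage sketch that pairs the score function estimator with a moving average baseline, reusing `objective_function`, `params`, `gaussian_builder`, `rng`, and `num_samples` from the Usage Examples below. The initial control variate state shown here, a moving-average value and a step counter, is an assumption about the expected state layout:

```python
import jax.numpy as jnp
import optax

# Build the control variate (a ControlVariate tuple of three functions).
cv = optax.monte_carlo.moving_avg_baseline(objective_function)

# Assumed initial state: (moving average of f, step counter).
cv_state = (jnp.array(0.0), 0)

jacobians, cv_state = optax.monte_carlo.control_variates_jacobians(
    function=objective_function,
    control_variate_from_function=cv,
    grad_estimator=optax.monte_carlo.score_function_jacobians,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples,
    control_variate_state=cv_state,
)
```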
## Usage Examples
```python
import jax
import jax.numpy as jnp
import optax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Example: score function gradient estimation
def objective_function(x):
    return jnp.sum(x**2)

# Parameters for a Gaussian distribution; the builder receives them
# positionally, so they are passed as a sequence.
params = [jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0])]  # mean, log_std

def gaussian_builder(mean, log_std):
    return tfd.Normal(loc=mean, scale=jnp.exp(log_std))

rng = jax.random.PRNGKey(42)
num_samples = 1000

# Use the score function estimator
gradients = optax.monte_carlo.score_function_jacobians(
    function=objective_function,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples,
)

# Use the pathwise estimator (requires a differentiable function)
gradients_pathwise = optax.monte_carlo.pathwise_jacobians(
    function=objective_function,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples,
)
```
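
The measure-valued estimator follows the same calling convention. Note, however, that it is restricted to Gaussian distributions and may require the module's own Gaussian distribution utility as the builder rather than an arbitrary `tfd.Normal` wrapper; treat the call below as a sketch:

```python
# Sketch only: the builder may need to be the module's Gaussian utility.
gradients_mv = optax.monte_carlo.measure_valued_jacobians(
    function=objective_function,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples,
    coupling=True,  # couple positive and negative samples to cut variance
)
```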
## Gradient Estimation Methods Comparison
| Method | Function Requirements | Distribution Requirements | Variance |
|--------|----------------------|---------------------------|----------|
| Score Function | Any | Differentiable log-probability | High |
| Pathwise | Differentiable | Reparameterizable | Low |
| Measure-valued | Any | Gaussian only | Medium |
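
The variance column can be checked empirically: each estimator returns per-sample Jacobian contributions, so their spread is directly comparable. A sketch reusing `gradients` and `gradients_pathwise` from the Usage Examples above:

```python
import jax.numpy as jnp

# Per-sample variance of the mean-parameter jacobians; the score function
# estimator typically shows markedly higher variance than the pathwise one.
print(jnp.var(gradients[0], axis=0))
print(jnp.var(gradients_pathwise[0], axis=0))
```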
## Import
```python
import optax.monte_carlo
# or
from optax.monte_carlo import (
    score_function_jacobians,
    pathwise_jacobians,
    measure_valued_jacobians,
    moving_avg_baseline,
    control_delta_method,
    control_variates_jacobians
)
```
## Types
```python { .api }
from typing import Any, Callable

# Control variate types
ControlVariate = tuple[
    Callable,  # Control variate computation function
    Callable,  # Expected value function
    Callable,  # State update function
]

CvState = Any  # Control variate state
```
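
As an illustration of the tuple layout, here is a hypothetical constant-zero control variate conforming to `ControlVariate`; the argument signatures are assumptions for illustration, not the module's exact interface:

```python
import jax.numpy as jnp

def zero_cv(params, samples, state=None):
    # h(x): the control variate evaluated at each sample.
    return jnp.zeros(samples.shape[0])

def zero_cv_expected_value(params, state=None):
    # E[h(x)] under the distribution: zero by construction.
    return jnp.zeros(())

def zero_cv_update_state(params, samples, state=None):
    # Stateless control variate: nothing to update.
    return state

zero_control_variate = (zero_cv, zero_cv_expected_value, zero_cv_update_state)
```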