# Monte Carlo Gradient Estimation

Utilities for efficient Monte Carlo gradient estimation methods. This module provides various techniques for approximating gradients of expectations, including score function estimators, pathwise estimators, and control variates for variance reduction.

**Note:** All functions in this module are deprecated and will be removed in Optax version 0.3.0.

## Capabilities

### Score Function Gradient Estimation

#### REINFORCE Estimator

Estimates gradients using the score function method (REINFORCE). Approximates ∇_θ E_{p(x;θ)}[f(x)] using E_{p(x;θ)}[f(x) ∇_θ log p(x;θ)].

```python { .api }
def score_function_jacobians(
    function,
    params,
    dist_builder,
    rng,
    num_samples
):
    """
    Score function gradient estimation (REINFORCE).

    Args:
        function: Function f(x) for gradient estimation
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation

    Returns:
        Sequence[chex.Array]: Tuple of jacobian vectors with shape
            num_samples x param.shape
    """
```

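To make the identity concrete, here is a minimal hand-rolled sketch for a diagonal Gaussian (illustrative only; `reinforce_grad_estimate` is a hypothetical helper, not part of the optax API). Differentiating the sample mean of f(x) · log p(x;θ), with the samples held fixed, yields exactly the score function estimate.

```python
import jax
import jax.numpy as jnp

def reinforce_grad_estimate(f, mean, log_std, rng, num_samples):
    # Sample x ~ N(mean, exp(log_std)^2); the samples stay fixed below.
    std = jnp.exp(log_std)
    x = mean + std * jax.random.normal(rng, (num_samples,) + mean.shape)

    def surrogate(mean, log_std):
        std = jnp.exp(log_std)
        # Diagonal-Gaussian log-density of the fixed samples x.
        log_prob = jnp.sum(
            -0.5 * ((x - mean) / std) ** 2 - log_std
            - 0.5 * jnp.log(2.0 * jnp.pi),
            axis=-1,
        )
        # Since x is a constant here, the gradient of this mean is the
        # REINFORCE estimate: mean over samples of f(x) * grad log p(x).
        return jnp.mean(jax.vmap(f)(x) * log_prob)

    return jax.grad(surrogate, argnums=(0, 1))(mean, log_std)
```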

### Pathwise Gradient Estimation

#### Reparameterization Trick

Estimates gradients using the pathwise method (reparameterization trick). Approximates ∇_θ E_{p(x;θ)}[f(x)] using E_{p(ε)}[∇_θ f(g(ε,θ))], where x = g(ε,θ).

```python { .api }
def pathwise_jacobians(
    function,
    params,
    dist_builder,
    rng,
    num_samples
):
    """
    Pathwise gradient estimation (reparameterization trick).

    Args:
        function: Function f(x) for gradient estimation (must be differentiable)
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation

    Returns:
        Sequence[chex.Array]: Tuple of jacobian vectors with shape
            num_samples x param.shape
    """
```

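The same diagonal-Gaussian toy, rewritten pathwise (again illustrative; `pathwise_grad_estimate` is a hypothetical helper). Here gradients flow through the sampling path x = mean + exp(log_std) · ε instead of through log p, which is why f must be differentiable.

```python
import jax
import jax.numpy as jnp

def pathwise_grad_estimate(f, mean, log_std, rng, num_samples):
    # Fixed base noise eps ~ N(0, I); theta enters only via g(eps, theta).
    eps = jax.random.normal(rng, (num_samples,) + mean.shape)

    def objective(mean, log_std):
        x = mean + jnp.exp(log_std) * eps  # reparameterized samples
        return jnp.mean(jax.vmap(f)(x))   # requires f differentiable

    return jax.grad(objective, argnums=(0, 1))(mean, log_std)
```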

### Measure-Valued Gradient Estimation

#### Measure Difference Method

Estimates gradients using differences between related measures. Currently only supports Gaussian random variables.

```python { .api }
def measure_valued_jacobians(
    function,
    params,
    dist_builder,
    rng,
    num_samples,
    coupling=True
):
    """
    Measure-valued gradient estimation.

    Args:
        function: Function f(x) for gradient estimation
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
        coupling: Whether to use coupling for positive/negative samples (default: True)

    Returns:
        Sequence[chex.Array]: Tuple of jacobian vectors with shape
            num_samples x param.shape
    """
```

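A hedged usage sketch. It assumes `optax.multi_normal` (a diagonal Gaussian built from a mean and a log-scale) as the builder and `params` passed as a sequence; check the optax source for the exact contract, since this estimator is restricted to Gaussians.

```python
import jax
import jax.numpy as jnp
import optax

def objective(x):
    return jnp.sum(x ** 2)

# Assumed layout: [mean, log_scale] for a diagonal Gaussian.
params = [jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0])]

jacobians = optax.monte_carlo.measure_valued_jacobians(
    function=objective,
    params=params,
    dist_builder=optax.multi_normal,  # Gaussian-only, per the note above
    rng=jax.random.PRNGKey(0),
    num_samples=1000,
    coupling=True,  # couple positive/negative samples to reduce variance
)
```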

### Control Variates

#### Moving Average Baseline

Implements a moving average baseline for variance reduction.

```python { .api }
def moving_avg_baseline(
    function,
    decay=0.99,
    zero_debias=True,
    use_decay_early_training_heuristic=True
):
    """
    Moving average baseline control variate.

    Args:
        function: Function for which to compute the control variate
        decay: Decay rate for the moving average (default: 0.99)
        zero_debias: Whether to use zero debiasing (default: True)
        use_decay_early_training_heuristic: Whether to use early training heuristic (default: True)

    Returns:
        ControlVariate: Tuple of three functions for computing control variate
    """
```

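As the Types section below documents, the return value is a triple of callables. A minimal unpacking sketch (variable names are illustrative; the callables' exact signatures are internal to optax):

```python
import jax.numpy as jnp
import optax

def objective(x):
    return jnp.sum(x ** 2)

# ControlVariate triple: (control variate, expected value, state update).
cv, expected_cv, update_cv_state = optax.monte_carlo.moving_avg_baseline(
    function=objective,
    decay=0.99,        # EMA decay of the baseline
    zero_debias=True,  # debias the EMA early in training
    use_decay_early_training_heuristic=True,
)
```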

#### Control Delta Method

Implements the control delta method control variate, built from a second-order Taylor expansion of the function.

```python { .api }
def control_delta_method(function):
    """
    Control delta method control variate (second-order Taylor expansion).

    Args:
        function: The function for which to compute the control variate

    Returns:
        ControlVariate: Tuple of three functions for computing control variate
    """
```

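A hedged sketch of the quadratic surrogate behind this control variate (illustrative, not the optax implementation). The point of the second-order Taylor expansion is that its expectation under a Gaussian has a closed form, E[h(x)] = f(μ) + ½ tr(HΣ):

```python
import jax
import jax.numpy as jnp

def taylor_surrogate(f, mu):
    # h(x) = f(mu) + g . (x - mu) + 0.5 * (x - mu)^T H (x - mu)
    g = jax.grad(f)(mu)
    h = jax.hessian(f)(mu)

    def surrogate(x):
        d = x - mu
        return f(mu) + d @ g + 0.5 * d @ h @ d

    return surrogate

def expected_surrogate(f, mu, sigma):
    # Closed-form Gaussian expectation of the surrogate for a diagonal
    # covariance Sigma = diag(sigma**2): f(mu) + 0.5 * tr(H Sigma).
    return f(mu) + 0.5 * jnp.sum(jnp.diag(jax.hessian(f)(mu)) * sigma ** 2)
```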

#### Control Variates with Jacobians

Combines control variates with gradient estimators for variance reduction.

```python { .api }
def control_variates_jacobians(
    function,
    control_variate_from_function,
    grad_estimator,
    params,
    dist_builder,
    rng,
    num_samples,
    control_variate_state=None,
    estimate_cv_coeffs=False,
    estimate_cv_coeffs_num_samples=20
):
    """
    Gradient estimation using control variates for variance reduction.

    Args:
        function: Function f(x) for which to estimate gradients
        control_variate_from_function: The control variate to use
        grad_estimator: The gradient estimator to compute gradients
        params: Parameters for constructing the distribution
        dist_builder: Constructor that builds a distribution from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
        control_variate_state: State of the control variate (optional)
        estimate_cv_coeffs: Whether to estimate optimal coefficients
        estimate_cv_coeffs_num_samples: Number of samples for coefficient estimation

    Returns:
        tuple[Sequence[chex.Array], CvState]: Jacobians and updated control variate state
    """
```

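Putting the pieces together, a hedged sketch of variance-reduced REINFORCE with the moving-average baseline. The initial `control_variate_state` is an assumption about the baseline's `(running_mean, step)` state, and `optax.multi_normal` is assumed as the Gaussian builder; consult the optax source for the exact structures.

```python
import jax
import jax.numpy as jnp
import optax

def objective(x):
    return jnp.sum(x ** 2)

params = [jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0])]  # mean, log_scale

jacobians, cv_state = optax.monte_carlo.control_variates_jacobians(
    function=objective,
    control_variate_from_function=optax.monte_carlo.moving_avg_baseline(objective),
    grad_estimator=optax.monte_carlo.score_function_jacobians,
    params=params,
    dist_builder=optax.multi_normal,
    rng=jax.random.PRNGKey(0),
    num_samples=1000,
    control_variate_state=(jnp.array(0.0), 0),  # assumed (running_mean, step)
)
```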

## Usage Examples

```python
import optax
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions  # distributions used by the builder below

# Example: Score function gradient estimation
def objective_function(x):
    return jnp.sum(x**2)

# Parameters for a Gaussian distribution, passed as a sequence so they
# can be unpacked into the builder as dist_builder(*params)
params = [jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0])]  # mean, log_std

def gaussian_builder(mean, log_std):
    return tfd.Normal(loc=mean, scale=jnp.exp(log_std))

rng = jax.random.PRNGKey(42)
num_samples = 1000

# Use score function estimator
gradients = optax.monte_carlo.score_function_jacobians(
    function=objective_function,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples
)

# Use pathwise estimator (requires differentiable function)
gradients_pathwise = optax.monte_carlo.pathwise_jacobians(
    function=objective_function,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples
)
```


## Gradient Estimation Methods Comparison

| Method | Function Requirements | Distribution Requirements | Variance |
|--------|----------------------|---------------------------|----------|
| Score Function | Any | Differentiable log-probability | High |
| Pathwise | Differentiable | Reparameterizable | Low |
| Measure-valued | Any | Gaussian only | Medium |

## Import

```python
import optax.monte_carlo

# or
from optax.monte_carlo import (
    score_function_jacobians,
    pathwise_jacobians,
    measure_valued_jacobians,
    moving_avg_baseline,
    control_delta_method,
    control_variates_jacobians
)
```

## Types

```python { .api }
from typing import Any, Callable

# Control variate types
ControlVariate = tuple[
    Callable,  # Control variate computation function
    Callable,  # Expected value function
    Callable   # State update function
]

CvState = Any  # Control variate state
```