A gradient processing and optimization library in JAX
---
Utilities for efficient Monte Carlo gradient estimation methods. This module provides various techniques for approximating gradients of expectations, including score function estimators, pathwise estimators, and control variates for variance reduction.
Note: All functions in this module are deprecated and will be removed in Optax version 0.3.0.
Estimates gradients using the score function method (REINFORCE). Approximates ∇_θ E_{p(x;θ)}[f(x)] using E_{p(x;θ)}[f(x) ∇_θ log p(x;θ)].
def score_function_jacobians(
function,
params,
dist_builder,
rng,
num_samples
):
"""
Score function gradient estimation (REINFORCE).
Args:
function: Function f(x) for gradient estimation
params: Parameters for constructing the distribution
dist_builder: Constructor for building distributions from parameters
rng: PRNGKey for random sampling
num_samples: Number of samples for gradient computation
Returns:
Sequence[chex.Array]: One jacobian array per parameter, each with shape [num_samples] + param.shape
"""

Estimates gradients using the pathwise method (reparameterization trick). Approximates ∇_θ E_{p(x;θ)}[f(x)] using E_{p(ε)}[∇_θ f(g(ε, θ))], where x = g(ε, θ).
def pathwise_jacobians(
function,
params,
dist_builder,
rng,
num_samples
):
"""
Pathwise gradient estimation (reparameterization trick).
Args:
function: Function f(x) for gradient estimation (must be differentiable)
params: Parameters for constructing the distribution
dist_builder: Constructor for building distributions from parameters
rng: PRNGKey for random sampling
num_samples: Number of samples for gradient computation
Returns:
Sequence[chex.Array]: One jacobian array per parameter, each with shape [num_samples] + param.shape
"""

Estimates gradients using differences between related measures (measure-valued derivatives). Currently only supports Gaussian distributions.
def measure_valued_jacobians(
function,
params,
dist_builder,
rng,
num_samples,
coupling=True
):
"""
Measure-valued gradient estimation.
Args:
function: Function f(x) for gradient estimation
params: Parameters for constructing the distribution
dist_builder: Constructor for building distributions from parameters
rng: PRNGKey for random sampling
num_samples: Number of samples for gradient computation
coupling: Whether to use coupling for positive/negative samples (default: True)
Returns:
Sequence[chex.Array]: One jacobian array per parameter, each with shape [num_samples] + param.shape
"""
Implements a moving average baseline for variance reduction.
def moving_avg_baseline(
function,
decay=0.99,
zero_debias=True,
use_decay_early_training_heuristic=True
):
"""
Moving average baseline control variate.
Args:
function: Function for which to compute the control variate
decay: Decay rate for the moving average (default: 0.99)
zero_debias: Whether to use zero debiasing (default: True)
use_decay_early_training_heuristic: Whether to use early training heuristic (default: True)
Returns:
ControlVariate: Tuple of (control variate function, expected value function, state update function)
"""
Implements the delta method control variate, using a second-order Taylor expansion of the function around the distribution's mean.
def control_delta_method(function):
"""
Delta method control variate, based on a second-order Taylor expansion.
Args:
function: The function for which to compute the control variate
Returns:
ControlVariate: Tuple of (control variate function, expected value function, state update function)
"""
Combines control variates with gradient estimators for variance reduction.
def control_variates_jacobians(
function,
control_variate_from_function,
grad_estimator,
params,
dist_builder,
rng,
num_samples,
control_variate_state=None,
estimate_cv_coeffs=False,
estimate_cv_coeffs_num_samples=20
):
"""
Gradient estimation using control variates for variance reduction.
Args:
function: Function f(x) for which to estimate gradients
control_variate_from_function: Constructor that takes the function and returns a ControlVariate (e.g. control_delta_method or moving_avg_baseline)
grad_estimator: The gradient estimator to compute gradients
params: Parameters for constructing the distribution
dist_builder: Constructor that builds a distribution from parameters
rng: PRNGKey for random sampling
num_samples: Number of samples for gradient computation
control_variate_state: State of the control variate (optional)
estimate_cv_coeffs: Whether to estimate optimal coefficients
estimate_cv_coeffs_num_samples: Number of samples for coefficient estimation
Returns:
tuple[Sequence[chex.Array], CvState]: Jacobians and updated control variate state
"""import optax
import jax
import jax.numpy as jnp
# Example: Score function gradient estimation
def objective_function(x):
return jnp.sum(x**2)
# Parameters for a diagonal Gaussian; the estimators pass them
# positionally to dist_builder, so use a sequence: (mean, log_std)
params = [jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0])]
def gaussian_builder(mean, log_std):
    return tfd.Normal(loc=mean, scale=jnp.exp(log_std))
rng = jax.random.PRNGKey(42)
num_samples = 1000
# Use score function estimator
gradients = optax.monte_carlo.score_function_jacobians(
function=objective_function,
params=params,
dist_builder=gaussian_builder,
rng=rng,
num_samples=num_samples
)
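# Each returned jacobian is per-sample with shape [num_samples] + param.shape;
# average over the sample axis to obtain the gradient estimate
grad_estimates = [jnp.mean(jac, axis=0) for jac in gradients]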
# Use pathwise estimator (requires differentiable function)
gradients_pathwise = optax.monte_carlo.pathwise_jacobians(
function=objective_function,
params=params,
dist_builder=gaussian_builder,
rng=rng,
num_samples=num_samples
)

| Method | Function Requirements | Distribution Requirements | Variance |
|---|---|---|---|
| Score Function | Any | Differentiable log-probability | High |
| Pathwise | Differentiable | Reparameterizable | Low |
| Measure-valued | Any | Gaussian only | Medium |
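The pieces above can be combined. The sketch below reuses objective_function, params, gaussian_builder, rng, and num_samples from the example, and assumes control_variate_from_function takes the objective and returns a ControlVariate (as control_delta_method does):

# Variance-reduced score-function gradients with the delta-method control variate
jacobians, cv_state = optax.monte_carlo.control_variates_jacobians(
    function=objective_function,
    control_variate_from_function=optax.monte_carlo.control_delta_method,
    grad_estimator=optax.monte_carlo.score_function_jacobians,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples,
    control_variate_state=None,
)
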
import optax.monte_carlo
# or
from optax.monte_carlo import (
score_function_jacobians,
pathwise_jacobians,
measure_valued_jacobians,
moving_avg_baseline,
control_delta_method,
control_variates_jacobians
)

# Control variate types
from typing import Any, Callable

ControlVariate = tuple[
Callable, # Control variate computation function
Callable, # Expected value function
Callable # State update function
]
CvState = Any  # Control variate state

Install with Tessl CLI
npx tessl i tessl/pypi-optax