
tessl/pypi-optax

A gradient processing and optimization library in JAX


Monte Carlo Gradient Estimation

Utilities for efficient Monte Carlo gradient estimation methods. This module provides various techniques for approximating gradients of expectations, including score function estimators, pathwise estimators, and control variates for variance reduction.

Note: All functions in this module are deprecated and will be removed in Optax version 0.3.0.

Capabilities

Score Function Gradient Estimation

REINFORCE Estimator

Estimates gradients using the score function method (REINFORCE). Approximates ∇_θ E_{p(x;θ)}[f(x)] as E_{p(x;θ)}[f(x) ∇_θ log p(x;θ)].

def score_function_jacobians(
    function, 
    params, 
    dist_builder, 
    rng, 
    num_samples
):
    """
    Score function gradient estimation (REINFORCE).
    
    Args:
        function: Function f(x) for gradient estimation
        params: Parameters for constructing the distribution  
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
    
    Returns:
        Sequence[chex.Array]: Jacobians, one per distribution parameter, each of shape (num_samples,) + param.shape
    """
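
As a framework-agnostic illustration of the identity above, the sketch below estimates ∇_μ E_{N(μ,σ²)}[x²] (analytically 2μ) with the score function. This is a NumPy sketch of the technique, not the optax implementation; all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, num_samples = 1.0, 1.0, 100_000

# Sample x ~ N(mu, sigma^2)
x = rng.normal(mu, sigma, size=num_samples)

# Score of a Gaussian w.r.t. its mean: d/dmu log N(x; mu, sigma^2) = (x - mu) / sigma^2
score = (x - mu) / sigma**2

# REINFORCE estimate of d/dmu E[f(x)] with f(x) = x^2;
# analytic value: d/dmu (mu^2 + sigma^2) = 2 * mu
score_grad_estimate = np.mean(x**2 * score)
```

Note that `f` itself is never differentiated; only the log-density is, which is why this estimator works for non-differentiable objectives.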

Pathwise Gradient Estimation

Reparameterization Trick

Estimates gradients using the pathwise method (reparameterization trick). Approximates ∇_θ E_{p(x;θ)}[f(x)] as E_{p(ε)}[∇_θ f(g(ε, θ))], where x = g(ε, θ).

def pathwise_jacobians(
    function, 
    params, 
    dist_builder, 
    rng, 
    num_samples
):
    """
    Pathwise gradient estimation (reparameterization trick).
    
    Args:
        function: Function f(x) for gradient estimation (must be differentiable)
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
    
    Returns:
        Sequence[chex.Array]: Jacobians, one per distribution parameter, each of shape (num_samples,) + param.shape
    """
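
The same toy gradient can be estimated pathwise: a Gaussian sample is reparameterized as x = μ + σε with ε ~ N(0, 1), and the derivative is pushed through f. A NumPy sketch under the same assumptions as above (f(x) = x², true gradient 2μ):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, num_samples = 1.0, 1.0, 100_000

# Reparameterize: x = g(eps, mu) = mu + sigma * eps, eps ~ N(0, 1)
eps = rng.normal(size=num_samples)
x = mu + sigma * eps

# Pathwise estimate: E_eps[d/dmu f(g(eps, mu))] with f(x) = x^2,
# so d/dmu f(g(eps, mu)) = 2 * x * dx/dmu = 2 * x  (since dx/dmu = 1)
pathwise_grad_estimate = np.mean(2.0 * x)
```

Unlike the score-function version, this requires f to be differentiable, which is what buys the lower variance.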

Measure-Valued Gradient Estimation

Measure Difference Method

Estimates gradients using differences between related measures. Currently only supports Gaussian random variables.

def measure_valued_jacobians(
    function, 
    params, 
    dist_builder, 
    rng, 
    num_samples, 
    coupling=True
):
    """
    Measure-valued gradient estimation.
    
    Args:
        function: Function f(x) for gradient estimation
        params: Parameters for constructing the distribution
        dist_builder: Constructor for building distributions from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
        coupling: Whether to use coupling for positive/negative samples (default: True)
    
    Returns:
        Sequence[chex.Array]: Jacobians, one per distribution parameter, each of shape (num_samples,) + param.shape
    """
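
For intuition, the derivative of a Gaussian density with respect to its mean decomposes into a positive and a negative measure. Following the decomposition tabulated in Mohamed et al. (2020), "Monte Carlo Gradient Estimation in Machine Learning", the positive and negative samples for the mean parameter are μ ± σW with W ~ Weibull(shape 2, scale √2) and constant 1/(σ√(2π)). The NumPy sketch below (not the optax implementation; coupling means both sample sets share the same W draws) estimates the same toy gradient ∇_μ E[x²] = 2μ:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, num_samples = 1.0, 1.0, 100_000

# Weibull(shape=2, scale=sqrt(2)) draws; numpy's weibull is scale-1
w = np.sqrt(2.0) * rng.weibull(2.0, size=num_samples)

# Coupled positive/negative samples share the same w draws
x_pos = mu + sigma * w
x_neg = mu - sigma * w

f = lambda x: x**2
const = 1.0 / (sigma * np.sqrt(2.0 * np.pi))

# Measure-valued estimate of d/dmu E[f(x)] (analytically 2 * mu here)
mv_grad_estimate = const * (np.mean(f(x_pos)) - np.mean(f(x_neg)))
```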

Control Variates

Moving Average Baseline

Implements a moving average baseline for variance reduction.

def moving_avg_baseline(
    function, 
    decay=0.99, 
    zero_debias=True, 
    use_decay_early_training_heuristic=True
):
    """
    Moving average baseline control variate.
    
    Args:
        function: Function for which to compute the control variate
        decay: Decay rate for the moving average (default: 0.99)
        zero_debias: Whether to use zero debiasing (default: True)
        use_decay_early_training_heuristic: Whether to use early training heuristic (default: True)
    
    Returns:
        ControlVariate: Tuple of three functions for computing control variate
    """
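
The idea behind the baseline: subtracting any constant b from f leaves the score-function estimator unbiased, since E[b ∇_θ log p] = 0, but a b near E[f] reduces variance. The sketch below uses a static batch-mean baseline (the moving-average and zero-debiasing machinery is omitted); it is an illustration of the principle, not the optax implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, num_samples = 1.0, 1.0, 100_000

x = rng.normal(mu, sigma, size=num_samples)
score = (x - mu) / sigma**2
f = x**2

# Baseline approximating E[f(x)]; subtracting it keeps the estimator
# unbiased because E[baseline * score] = baseline * E[score] = 0
baseline = f.mean()

per_sample_plain = f * score
per_sample_baselined = (f - baseline) * score

# Both estimate d/dmu E[x^2] = 2 * mu, but the baselined version
# has markedly lower variance
grad_plain = per_sample_plain.mean()
grad_baselined = per_sample_baselined.mean()
var_plain = per_sample_plain.var()
var_baselined = per_sample_baselined.var()
```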

Control Delta Method

Implements the control delta covariant method using second-order Taylor expansion.

def control_delta_method(function):
    """
    Control delta covariant method control variate.
    
    Args:
        function: The function for which to compute the control variate
    
    Returns:
        ControlVariate: Tuple of three functions for computing control variate
    """
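
The delta method uses the second-order Taylor expansion of f around the distribution mean as the control variate; under a Gaussian its expectation is available in closed form, E[cv] = f(μ) + ½f''(μ)σ². For a quadratic f the expansion is exact, so the residual f − cv vanishes and the score term contributes zero variance. A NumPy sketch of that limiting case (illustrative, not the optax implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, num_samples = 1.0, 1.0, 1_000

x = rng.normal(mu, sigma, size=num_samples)
score = (x - mu) / sigma**2

def f(x):
    return x**2

# Second-order Taylor expansion of f around mu; exact for f(x) = x^2
def taylor_cv(x):
    return f(mu) + 2.0 * mu * (x - mu) + (x - mu) ** 2

# Closed-form gradient of E[taylor_cv(x)] under N(mu, sigma^2):
# E[cv] = f(mu) + 0.5 * f''(mu) * sigma^2 = mu^2 + sigma^2  =>  d/dmu = 2 * mu
analytic_part = 2.0 * mu

# Score-function term on the residual f - cv; identically zero here
residual_part = (f(x) - taylor_cv(x)) * score

delta_grad_estimate = analytic_part + residual_part.mean()
```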

Control Variates with Jacobians

Combines control variates with gradient estimators for variance reduction.

def control_variates_jacobians(
    function,
    control_variate_from_function,
    grad_estimator,
    params,
    dist_builder,
    rng,
    num_samples,
    control_variate_state=None,
    estimate_cv_coeffs=False,
    estimate_cv_coeffs_num_samples=20
):
    """
    Gradient estimation using control variates for variance reduction.
    
    Args:
        function: Function f(x) for which to estimate gradients
        control_variate_from_function: The control variate to use
        grad_estimator: The gradient estimator to compute gradients
        params: Parameters for constructing the distribution
        dist_builder: Constructor that builds a distribution from parameters
        rng: PRNGKey for random sampling
        num_samples: Number of samples for gradient computation
        control_variate_state: State of the control variate (optional)
        estimate_cv_coeffs: Whether to estimate optimal coefficients
        estimate_cv_coeffs_num_samples: Number of samples for coefficient estimation
    
    Returns:
        tuple[Sequence[chex.Array], CvState]: Jacobians and updated control variate state
    """

Usage Examples

import optax.monte_carlo
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Example: Score function gradient estimation
def objective_function(x):
    return jnp.sum(x**2)

# Distribution parameters are unpacked positionally into dist_builder,
# so pass them as a sequence rather than a dict
params = [jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0])]  # mean, log_std

def gaussian_builder(mean, log_std):
    return tfd.Normal(loc=mean, scale=jnp.exp(log_std))

rng = jax.random.PRNGKey(42)
num_samples = 1000

# Use score function estimator
gradients = optax.monte_carlo.score_function_jacobians(
    function=objective_function,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples
)

# Use pathwise estimator (requires differentiable function)
gradients_pathwise = optax.monte_carlo.pathwise_jacobians(
    function=objective_function,
    params=params,
    dist_builder=gaussian_builder,
    rng=rng,
    num_samples=num_samples
)

Gradient Estimation Methods Comparison

| Method         | Function Requirements | Distribution Requirements      | Variance |
|----------------|-----------------------|--------------------------------|----------|
| Score Function | Any                   | Differentiable log-probability | High     |
| Pathwise       | Differentiable        | Reparameterizable              | Low      |
| Measure-valued | Any                   | Gaussian only                  | Medium   |
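
The variance column can be checked empirically. Using the same samples for both estimators on the toy Gaussian problem from the examples above, the per-sample pathwise estimates have much lower variance than the per-sample score-function estimates (a NumPy sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma, num_samples = 1.0, 1.0, 100_000

eps = rng.normal(size=num_samples)
x = mu + sigma * eps  # same underlying samples for both estimators

# Per-sample score-function estimates of d/dmu E[x^2]
score_samples = x**2 * (x - mu) / sigma**2

# Per-sample pathwise estimates: f'(x) * dx/dmu = 2 * x
pathwise_samples = 2.0 * x

score_var = score_samples.var()
pathwise_var = pathwise_samples.var()
```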

Import

import optax.monte_carlo
# or
from optax.monte_carlo import (
    score_function_jacobians,
    pathwise_jacobians,
    measure_valued_jacobians,
    moving_avg_baseline,
    control_delta_method,
    control_variates_jacobians
)

Types

# Control variate types
from typing import Any, Callable

ControlVariate = tuple[
    Callable,  # Control variate computation function
    Callable,  # Expected value function
    Callable   # State update function
]

CvState = Any  # Control variate state

Install with Tessl CLI

npx tessl i tessl/pypi-optax
