
tessl/pypi-optax

A gradient processing and optimization library in JAX


Perturbation-Based Optimization

Utilities for perturbation-based optimization that enable gradient-based optimization of non-differentiable functions. This module provides techniques to create differentiable approximations of functions using stochastic smoothing through noise perturbations.

Capabilities

Perturbed Function Creation

Creates differentiable approximations of potentially non-differentiable functions using stochastic perturbations.

def make_perturbed_fun(
    fun, 
    num_samples=1000, 
    sigma=0.1, 
    noise=Gumbel(), 
    use_baseline=True
):
    """
    Creates a differentiable approximation of a function using stochastic perturbations.
    
    Transforms a potentially non-differentiable function into a smoothed, differentiable
    version by adding noise and averaging over multiple samples. Uses the score function
    estimator (REINFORCE) to provide unbiased Monte-Carlo estimates of the derivatives of
    this smoothed function (not of the original function).
    
    Args:
        fun: The function to transform (pytree → pytree with JAX array leaves)
        num_samples: Number of perturbed outputs to average over (default: 1000)
        sigma: Scale of random perturbation (default: 0.1)
        noise: Distribution object with sample and log_prob methods (default: Gumbel())
        use_baseline: Whether to use unperturbed function value for variance reduction (default: True)
    
    Returns:
        Callable: New function with signature (PRNGKey, ArrayTree) → ArrayTree
    """

Noise Distributions

Gumbel Distribution

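Standard Gumbel distribution, commonly used in perturbation-based optimization because of the Gumbel-max trick: the argmax of scores θ plus i.i.d. standard Gumbel noise equals index i with probability softmax(θ)ᵢ, so a perturbed argmax behaves like sampling from a softmax.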

class Gumbel:
    """Gumbel distribution for perturbation-based optimization."""
    
    def sample(self, key, sample_shape=(), dtype=float):
        """
        Generate random samples from the Gumbel distribution.
        
        Args:
            key: PRNG key for random sampling
            sample_shape: Shape of samples to generate (default: ())
            dtype: Data type for samples (default: float)
        
        Returns:
            jax.Array: Gumbel-distributed random values
        """
    
    def log_prob(self, inputs):
        """
        Compute log probability density of inputs.
        
        Args:
            inputs: JAX array for which to compute log probabilities
        
        Returns:
            jax.Array: Log probabilities using formula -inputs - exp(-inputs)
        """

Normal Distribution

Standard normal distribution as an alternative noise source for perturbations.

class Normal:
    """Normal (Gaussian) distribution for perturbation-based optimization."""
    
    def sample(self, key, sample_shape=(), dtype=float):
        """
        Generate random samples from the standard normal distribution.
        
        Args:
            key: PRNG key for random sampling
            sample_shape: Shape of samples to generate (default: ())
            dtype: Data type for samples (default: float)
        
        Returns:
            jax.Array: Normally-distributed random values (mean=0, std=1)
        """
    
    def log_prob(self, inputs):
        """
        Compute log probability density of inputs.
        
        Args:
            inputs: JAX array for which to compute log probabilities
        
        Returns:
            jax.Array: Log probabilities using formula -0.5 * inputs² (the constant -0.5 * log(2π) is omitted)
        """

Usage Examples

Basic Usage

import jax
import jax.numpy as jnp
import optax

# Example: smoothing a sum of ReLUs, which is not differentiable at 0
def non_differentiable_fn(x):
    return jnp.sum(jnp.maximum(x, 0.0))  # sum of ReLU activations

# Create perturbed version
key = jax.random.PRNGKey(42)
perturbed_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=1000,
    sigma=0.1,
    noise=optax.perturbations.Gumbel()
)

# Gradient of the smoothed function (a Monte-Carlo estimate that depends on key)
x = jnp.array([-1.0, 0.5, 2.0])
gradient = jax.grad(perturbed_fn, argnums=1)(key, x)
print(f"Gradient: {gradient}")

Using Different Noise Distributions

# Using Gumbel noise (default)
gumbel_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Gumbel()
)

# Using Normal noise
normal_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Normal()
)

# Compare gradients from different noise distributions
key1, key2 = jax.random.split(key)
grad_gumbel = jax.grad(gumbel_fn, argnums=1)(key1, x)
grad_normal = jax.grad(normal_fn, argnums=1)(key2, x)
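Beyond Monte-Carlo noise, the two gradients also differ systematically because the two noises have different shapes:

- The standard Gumbel distribution is skewed. Its mean is the Euler-Mascheroni constant γ ≈ 0.5772 and its variance is π²/6 ≈ 1.645.
- The standard normal is centred, with unit variance.

Given the standard Gumbel described above, Gumbel perturbations evaluate `fun` around x + γσ on average. A NumPy check of these moments, independent of optax:

```python
import math
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
gumbel = rng.gumbel(size=n)            # standard Gumbel draws (loc=0, scale=1)
normal = rng.standard_normal(n)        # standard normal draws

EULER_GAMMA = 0.5772156649015329
print(gumbel.mean(), EULER_GAMMA)      # mean offset of Gumbel noise
print(gumbel.var(), math.pi**2 / 6)    # variance pi^2 / 6
print(normal.mean(), normal.var())     # ≈ 0 and ≈ 1

sigma = 0.1
print(sigma * gumbel.mean())           # average shift of x + sigma * eps, ≈ 0.058
```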

Tuning Hyperparameters

# Adjust perturbation scale and sample count
fine_tuned_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=5000,  # More samples: lower Monte-Carlo variance, higher cost
    sigma=0.05,        # Smaller sigma: closer to the original function, noisier gradients
    use_baseline=True  # Subtract the unperturbed value to reduce variance
)
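The trade-off behind these settings can be measured directly. The NumPy sketch below repeats the score-function gradient estimate with independent noise and reports its spread. Its assumptions are Gaussian noise, f(z) = |z| evaluated at x = 1, and the illustrative helper name `gradient_estimates`.

Without a baseline, the spread grows roughly like 1/(σ·√num_samples). Subtracting f(x) removes most of it when f is close to linear on the scale of σ.

```python
import numpy as np

def gradient_estimates(f, x, sigma, num_samples, use_baseline, repeats=200, seed=0):
    # `repeats` independent score-function estimates of d/dx E[f(x + sigma * eps)],
    # eps ~ N(0, 1), each averaging `num_samples` perturbed evaluations.
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((repeats, num_samples))
    values = f(x + sigma * eps)
    if use_baseline:
        values = values - f(x)
    return (values * eps).mean(axis=1) / sigma

for sigma in (0.2, 0.05):
    for use_baseline in (False, True):
        est = gradient_estimates(np.abs, 1.0, sigma, 1000, use_baseline)
        print(f"sigma={sigma}, baseline={use_baseline}: "
              f"mean={est.mean():.3f}, std={est.std():.3f}")
```

All four means should be close to the true smoothed gradient, which is about 1 here. Without a baseline, the spread at sigma=0.05 is several times larger than at sigma=0.2. With a baseline, the spread stays small for both values.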

Real-World Application: Optimizing Discrete Choices

def discrete_objective(weights):
    """Example function with a discrete (argmax) decision."""
    # Score each of the three choices (elementwise product keeps one score per choice)
    scores = weights * jnp.array([1.0, 2.0, 3.0])
    best_choice = jnp.argmax(scores)  # argmax itself has no useful gradient
    return -scores[best_choice]  # Negative because we want to maximize

# Make it differentiable
differentiable_objective = optax.perturbations.make_perturbed_fun(
    fun=discrete_objective,
    num_samples=2000,
    sigma=0.2
)

# Now we can use gradient-based optimization
def optimize_discrete_choice():
    weights = jnp.array([0.1, 0.1, 0.1])
    optimizer = optax.adam(0.01)
    opt_state = optimizer.init(weights)
    
    for step in range(100):
        key = jax.random.PRNGKey(step)
        loss_val, grads = jax.value_and_grad(differentiable_objective, argnums=1)(key, weights)
        updates, opt_state = optimizer.update(grads, opt_state, weights)
        weights = optax.apply_updates(weights, updates)
        
        if step % 20 == 0:
            print(f"Step {step}, Loss: {loss_val:.3f}")
    
    return weights

optimized_weights = optimize_discrete_choice()
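The same idea can be checked without optax on a variant whose loss depends only on which choice wins. That loss is piecewise constant, so its ordinary gradient is zero almost everywhere. The NumPy sketch below is illustrative. The rewards, the Gaussian noise, the use of plain gradient descent in place of `optax.adam`, and the helper names are all assumptions.

```python
import numpy as np

REWARDS = np.array([1.0, 2.0, 3.0])   # hypothetical reward of each discrete choice

def discrete_loss(weights):
    # Loss depends only on which index wins the argmax: piecewise constant in weights.
    # Works on a single weight vector (3,) or a batch (n, 3).
    return -REWARDS[np.argmax(weights, axis=-1)]

def smoothed_grad(loss, w, sigma, num_samples, rng):
    # Score-function estimate of grad E[loss(w + sigma * eps)], eps ~ N(0, I),
    # using loss(w) as the baseline.
    eps = rng.standard_normal((num_samples, w.size))
    values = loss(w + sigma * eps) - loss(w)
    return (values[:, None] * eps).mean(axis=0) / sigma

def optimize(steps=100, lr=0.05, sigma=0.2, num_samples=2000, seed=0):
    rng = np.random.default_rng(seed)
    w = np.array([0.1, 0.1, 0.1])     # same starting point as above
    for _ in range(steps):
        w = w - lr * smoothed_grad(discrete_loss, w, sigma, num_samples, rng)
    return w

w = optimize()
print(w, discrete_loss(w))   # the weight of the highest-reward choice ends up largest
```

Once one weight leads the others by several σ, almost every perturbed sample picks the same choice. At that point the baseline-subtracted values are nearly all zero, and the updates stop.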

Mathematical Foundation

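The perturbation method is based on the score function estimator.

For a function f(x) and noise ε with density p(ε), the perturbed (smoothed) function is:

F(x) = E[f(x + σε)] = ∫ f(y) p((y − x)/σ) σ⁻ᵈ dy

where d is the dimension of x. In the last integral only the density depends on x, so differentiating it gives the gradient without differentiating f:

∇F(x) = −(1/σ) E[f(x + σε) ∇ log p(ε)]

With N samples εᵢ ~ p, the gradient is estimated by:

∇F(x) ≈ −(1/(σN)) Σᵢ f(x + σεᵢ) ∇ log p(εᵢ)

Because E[∇ log p(ε)] = 0, subtracting the unperturbed value f(x) from each f(x + σεᵢ) keeps the estimate unbiased and typically lowers its variance. This corresponds to the use_baseline option. For standard normal noise, ∇ log p(ε) = −ε, so the estimate becomes (1/(σN)) Σᵢ f(x + σεᵢ) εᵢ.

This provides an unbiased estimate of the gradient of the smoothed function F, even when f itself is non-differentiable or piecewise constant.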
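The score-function identity ∇F(x) = −(1/σ) E[f(x + σε) ∇ log p(ε)] is easy to get wrong by a sign or a factor of σ, so it is worth checking numerically against cases with known answers. The sketch below is NumPy only, independent of optax, and its helper name is illustrative. It uses two closed forms:

- With standard normal noise, E[sin(x + σε)] = sin(x)·exp(−σ²/2).
- With standard Gumbel noise, E[exp(x + σε)] = exp(x)·Γ(1 − σ) for σ < 1.

In both cases, differentiating in x gives the exact smoothed gradient.

```python
import math
import numpy as np

def score_function_grad(f, x, sigma, eps, score):
    # -(1/sigma) * mean_i f(x + sigma * eps_i) * d/d eps log p(eps_i), for scalar x
    return -np.mean(f(x + sigma * eps) * score(eps)) / sigma

rng = np.random.default_rng(0)
n = 400_000

# Standard normal noise: d/d eps log p(eps) = -eps
x, sigma = 0.3, 0.5
eps = rng.standard_normal(n)
g_normal = score_function_grad(np.sin, x, sigma, eps, lambda e: -e)
print(g_normal, math.cos(x) * math.exp(-sigma**2 / 2))

# Standard Gumbel noise: log p(eps) = -eps - exp(-eps), so the score is -1 + exp(-eps)
x, sigma = 0.0, 0.2
eps = rng.gumbel(size=n)
g_gumbel = score_function_grad(np.exp, x, sigma, eps, lambda e: -1.0 + np.exp(-e))
print(g_gumbel, math.exp(x) * math.gamma(1.0 - sigma))
```

Each printed pair should agree up to Monte-Carlo error. Dropping the 1/σ factor or the minus sign makes them disagree.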

When to Use Perturbations

  • Discrete Operations: Functions containing argmax, argmin, or discrete sampling
  • Non-smooth Functions: Functions with discontinuities or non-differentiable points
  • Combinatorial Optimization: Problems requiring optimization over discrete choices
  • Reinforcement Learning: Policy optimization with discrete action spaces

Import

import optax.perturbations
# or
from optax.perturbations import make_perturbed_fun, Gumbel, Normal

Types

import jax

# Informal interface that the `noise` argument must satisfy (Gumbel and Normal follow it)
class NoiseDistribution:
    def sample(self, key, sample_shape=(), dtype=float) -> jax.Array:
        """Generate random samples."""
    
    def log_prob(self, inputs: jax.Array) -> jax.Array:
        """Compute the log probability density (an additive constant may be omitted)."""

Install with Tessl CLI

npx tessl i tessl/pypi-optax
