A gradient processing and optimization library in JAX.

Utilities for perturbation-based optimization that enable gradient-based optimization of non-differentiable functions. This module provides techniques to create differentiable approximations of potentially non-differentiable functions using stochastic smoothing through noise perturbations: the function is evaluated at randomly perturbed inputs and the results are averaged, yielding a smooth surrogate.
def make_perturbed_fun(
    fun,
    num_samples=1000,
    sigma=0.1,
    noise=Gumbel(),
    use_baseline=True,
):
    """Creates a differentiable approximation of a function using stochastic perturbations.

    Transforms a potentially non-differentiable function into a smoothed, differentiable
    version by adding noise and averaging over multiple samples. Uses the score function
    estimator (REINFORCE) to provide unbiased Monte-Carlo estimates of derivatives.

    Args:
        fun: The function to transform (pytree → pytree with JAX array leaves).
        num_samples: Number of perturbed outputs to average over (default: 1000).
        sigma: Scale of random perturbation (default: 0.1).
        noise: Distribution object with ``sample`` and ``log_prob`` methods
            (default: ``Gumbel()``). NOTE(review): the default instance is created
            once at definition time, so ``Gumbel`` must already be defined.
        use_baseline: Whether to use the unperturbed function value for variance
            reduction (default: True).

    Returns:
        Callable: New function with signature (PRNGKey, ArrayTree) → ArrayTree.
    """


# Standard Gumbel distribution commonly used in perturbation-based optimization
# due to its mathematical properties.
class Gumbel:
    """Gumbel distribution for perturbation-based optimization."""

    def sample(self, key, sample_shape=(), dtype=float):
        """Generate random samples from the Gumbel distribution.

        Args:
            key: PRNG key for random sampling.
            sample_shape: Shape of samples to generate (default: ()).
            dtype: Data type for samples (default: float).

        Returns:
            jax.Array: Gumbel-distributed random values.
        """
        # Delegate to JAX's built-in standard-Gumbel sampler.
        return jax.random.gumbel(key, shape=sample_shape, dtype=dtype)

    def log_prob(self, inputs):
        """Compute log probability density of inputs.

        Args:
            inputs: JAX array for which to compute log probabilities.

        Returns:
            jax.Array: Log probabilities using the formula ``-inputs - exp(-inputs)``.
        """
        return -inputs - jnp.exp(-inputs)


# Standard normal distribution as an alternative noise source for perturbations.
class Normal:
    """Normal (Gaussian) distribution for perturbation-based optimization."""

    def sample(self, key, sample_shape=(), dtype=float):
        """Generate random samples from the standard normal distribution.

        Args:
            key: PRNG key for random sampling.
            sample_shape: Shape of samples to generate (default: ()).
            dtype: Data type for samples (default: float).

        Returns:
            jax.Array: Normally-distributed random values (mean=0, std=1).
        """
        # Delegate to JAX's built-in standard-normal sampler.
        return jax.random.normal(key, shape=sample_shape, dtype=dtype)

    def log_prob(self, inputs):
        """Compute log probability density of inputs.

        Args:
            inputs: JAX array for which to compute log probabilities.

        Returns:
            jax.Array: Log probabilities using the formula ``-0.5 * inputs**2``
            (the constant normalization term is omitted, per the documented formula).
        """
        return -0.5 * inputs**2


import jax
import jax.numpy as jnp
import optax
# Example: making a non-differentiable ReLU function differentiable.
def non_differentiable_fn(x):
    """Sum of ReLU activations; piecewise-linear and not differentiable at 0."""
    return jnp.sum(jnp.maximum(x, 0.0))  # ReLU activation
# Create a perturbed (smoothed) version of the function.
key = jax.random.PRNGKey(42)
perturbed_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=1000,
    sigma=0.1,
    noise=optax.perturbations.Gumbel(),
)

# Now we can compute gradients. argnums=1 differentiates with respect to x
# (argument 0 is the PRNG key).
x = jnp.array([-1.0, 0.5, 2.0])
gradient = jax.grad(perturbed_fn, argnums=1)(key, x)
print(f"Gradient: {gradient}")

# Using Gumbel noise (default):
gumbel_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Gumbel(),
)

# Using Normal noise:
normal_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Normal(),
)

# Compare gradients from the two noise distributions; independent keys keep
# the two Monte-Carlo estimates uncorrelated.
key1, key2 = jax.random.split(key)
grad_gumbel = jax.grad(gumbel_fn, argnums=1)(key1, x)
grad_normal = jax.grad(normal_fn, argnums=1)(key2, x)

# Adjust perturbation scale and sample count:
fine_tuned_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=5000,  # more samples for a better approximation
    sigma=0.05,  # smaller perturbations for a finer approximation
    use_baseline=True,  # use baseline for variance reduction
)


def discrete_objective(weights):
    """Example function with discrete operations."""
    # Simulate some discrete decision-making process.
    scores = weights @ jnp.array([1.0, 2.0, 3.0])
    best_choice = jnp.argmax(scores)  # non-differentiable
    return -scores[best_choice]  # negative because we want to maximize
# Make it differentiable via perturbation smoothing.
differentiable_objective = optax.perturbations.make_perturbed_fun(
    fun=discrete_objective,
    num_samples=2000,
    sigma=0.2,
)
# Now we can use gradient-based optimization.
def optimize_discrete_choice():
    """Run 100 Adam steps on the smoothed discrete objective and return the weights."""
    weights = jnp.array([0.1, 0.1, 0.1])
    optimizer = optax.adam(0.01)
    opt_state = optimizer.init(weights)
    for step in range(100):
        # Fresh key each step so every gradient estimate uses new perturbations.
        key = jax.random.PRNGKey(step)
        loss_val, grads = jax.value_and_grad(differentiable_objective, argnums=1)(key, weights)
        updates, opt_state = optimizer.update(grads, opt_state, weights)
        weights = optax.apply_updates(weights, updates)
        if step % 20 == 0:
            print(f"Step {step}, Loss: {loss_val:.3f}")
    return weights


optimized_weights = optimize_discrete_choice()

# The perturbation method is based on the score function estimator:
For a function f(x) and noise distribution p(ε), the perturbed function is:

    F(x) = E[f(x + σε)]

The gradient is estimated using:

    ∇F(x) ≈ (1/N) Σᵢ f(x + σεᵢ) ∇ log p(εᵢ)

This provides an unbiased estimate of the gradient even when f is non-differentiable.
import optax.perturbations
# or
from optax.perturbations import make_perturbed_fun, Gumbel, Normal


# Distribution interface expected by make_perturbed_fun:
class NoiseDistribution:
    """Interface for noise distributions: ``sample`` and ``log_prob`` methods."""

    def sample(self, key, sample_shape=(), dtype=float) -> jax.Array:
        """Generate random samples."""

    def log_prob(self, inputs: jax.Array) -> jax.Array:
        """Compute log probability density."""


# Install with the Tessl CLI:
#   npx tessl i tessl/pypi-optax