# Perturbation-Based Optimization

Utilities for perturbation-based optimization, enabling gradient-based optimization of non-differentiable functions. This module provides techniques for creating differentiable approximations of functions via stochastic smoothing with noise perturbations.

## Capabilities

### Perturbed Function Creation

Creates differentiable approximations of potentially non-differentiable functions using stochastic perturbations.
```python { .api }
def make_perturbed_fun(
    fun,
    num_samples=1000,
    sigma=0.1,
    noise=Gumbel(),
    use_baseline=True
):
    """
    Creates a differentiable approximation of a function using stochastic perturbations.

    Transforms a potentially non-differentiable function into a smoothed, differentiable
    version by adding noise and averaging over multiple samples. Uses the score function
    estimator (REINFORCE) to provide unbiased Monte-Carlo estimates of derivatives.

    Args:
        fun: The function to transform (pytree → pytree with JAX array leaves)
        num_samples: Number of perturbed outputs to average over (default: 1000)
        sigma: Scale of random perturbation (default: 0.1)
        noise: Distribution object with sample and log_prob methods (default: Gumbel())
        use_baseline: Whether to use the unperturbed function value for variance reduction (default: True)

    Returns:
        Callable: New function with signature (PRNGKey, ArrayTree) → ArrayTree
    """
```
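Because the wrapped function maps pytrees to pytrees, the smoothed version also works on nested structures. The snippet below is a minimal sketch of that usage, following the `(key, inputs)` calling convention documented above; the function, parameter values, and structure are illustrative only.

```python
import jax
import jax.numpy as jnp
import optax

# A piecewise-constant function over a dict pytree: counts the positive
# entries in each leaf, so its ordinary gradient is zero almost everywhere.
def count_positive(params):
    return {name: jnp.sum(jnp.where(leaf > 0, 1.0, 0.0))
            for name, leaf in params.items()}

smoothed_counts = optax.perturbations.make_perturbed_fun(
    count_positive, num_samples=500, sigma=0.5
)

params = {"w": jnp.array([-0.2, 0.3, 1.5]), "b": jnp.array([0.0])}
key = jax.random.PRNGKey(0)
print(smoothed_counts(key, params))  # smoothed estimate of each count
```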
### Noise Distributions

#### Gumbel Distribution

Standard Gumbel distribution, commonly used in perturbation-based optimization because perturbing scores with Gumbel noise and taking an argmax is equivalent to sampling from a softmax distribution (the Gumbel-max trick).
```python { .api }
class Gumbel:
    """Gumbel distribution for perturbation-based optimization."""

    def sample(self, key, sample_shape=(), dtype=float):
        """
        Generate random samples from the Gumbel distribution.

        Args:
            key: PRNG key for random sampling
            sample_shape: Shape of samples to generate (default: ())
            dtype: Data type for samples (default: float)

        Returns:
            jax.Array: Gumbel-distributed random values
        """

    def log_prob(self, inputs):
        """
        Compute log probability density of inputs.

        Args:
            inputs: JAX array for which to compute log probabilities

        Returns:
            jax.Array: Log probabilities using formula -inputs - exp(-inputs)
        """
```
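These two methods are all the smoothing machinery relies on. For reference, a self-contained object exposing the same interface could be sketched as follows; this is an illustrative stand-in, not the library's implementation:

```python
import jax
import jax.numpy as jnp

class StandardGumbel:
    """Illustrative stand-in exposing the sample/log_prob interface."""

    def sample(self, key, sample_shape=(), dtype=float):
        # jax.random.gumbel draws from the standard Gumbel(0, 1) distribution.
        return jax.random.gumbel(key, shape=sample_shape, dtype=dtype)

    def log_prob(self, inputs):
        # Standard Gumbel(0, 1) log-density: -x - exp(-x).
        return -inputs - jnp.exp(-inputs)
```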
#### Normal Distribution

Standard normal distribution as an alternative noise source for perturbations.

```python { .api }
class Normal:
    """Normal (Gaussian) distribution for perturbation-based optimization."""

    def sample(self, key, sample_shape=(), dtype=float):
        """
        Generate random samples from the standard normal distribution.

        Args:
            key: PRNG key for random sampling
            sample_shape: Shape of samples to generate (default: ())
            dtype: Data type for samples (default: float)

        Returns:
            jax.Array: Normally-distributed random values (mean=0, std=1)
        """

    def log_prob(self, inputs):
        """
        Compute log probability density of inputs.

        Args:
            inputs: JAX array for which to compute log probabilities

        Returns:
            jax.Array: Log probabilities (up to an additive constant) using formula -0.5 * inputs**2
        """
```
## Usage Examples

### Basic Usage

```python
import jax
import jax.numpy as jnp
import optax

# Example: smoothing a ReLU-based function (non-differentiable at 0)
def non_differentiable_fn(x):
    return jnp.sum(jnp.maximum(x, 0.0))  # ReLU activation

# Create the perturbed (smoothed) version
key = jax.random.PRNGKey(42)
perturbed_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=1000,
    sigma=0.1,
    noise=optax.perturbations.Gumbel()
)

# Now we can compute gradients
x = jnp.array([-1.0, 0.5, 2.0])
gradient = jax.grad(perturbed_fn, argnums=1)(key, x)
print(f"Gradient: {gradient}")
```
### Using Different Noise Distributions

```python
# Using Gumbel noise (default)
gumbel_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Gumbel()
)

# Using Normal noise
normal_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Normal()
)

# Compare gradients from different noise distributions
key1, key2 = jax.random.split(key)
grad_gumbel = jax.grad(gumbel_fn, argnums=1)(key1, x)
grad_normal = jax.grad(normal_fn, argnums=1)(key2, x)
```
### Tuning the Perturbation Hyperparameters

```python
# Adjust the perturbation scale and sample count
fine_tuned_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=5000,   # More samples: lower-variance gradient estimates
    sigma=0.05,         # Smaller perturbations: closer to the original function, but noisier gradients
    use_baseline=True   # Subtract the unperturbed value to reduce variance
)
```
### Real-World Application: Optimizing Discrete Choices

```python
def discrete_objective(weights):
    """Example function with discrete operations."""
    # Simulate a discrete decision-making process: one score per option
    scores = weights * jnp.array([1.0, 2.0, 3.0])
    best_choice = jnp.argmax(scores)  # Non-differentiable
    return -scores[best_choice]       # Negative because we want to maximize

# Make it differentiable
differentiable_objective = optax.perturbations.make_perturbed_fun(
    fun=discrete_objective,
    num_samples=2000,
    sigma=0.2
)

# Now we can use gradient-based optimization
def optimize_discrete_choice():
    weights = jnp.array([0.1, 0.1, 0.1])
    optimizer = optax.adam(0.01)
    opt_state = optimizer.init(weights)

    for step in range(100):
        key = jax.random.PRNGKey(step)
        loss_val, grads = jax.value_and_grad(differentiable_objective, argnums=1)(key, weights)
        updates, opt_state = optimizer.update(grads, opt_state, weights)
        weights = optax.apply_updates(weights, updates)

        if step % 20 == 0:
            print(f"Step {step}, Loss: {float(loss_val):.3f}")

    return weights

optimized_weights = optimize_discrete_choice()
```
## Mathematical Foundation

The perturbation method is based on the score function estimator.

For a function f(x) and noise distribution p(ε), the perturbed (smoothed) function is:
```
F(x) = E[f(x + σε)]
```

Rewriting this expectation as an integral over the perturbed input and differentiating the noise density rather than f gives the score-function (REINFORCE) estimate of the gradient:
```
∇F(x) ≈ -(1/(Nσ)) Σᵢ f(x + σεᵢ) ∇ log p(εᵢ)
```

With `use_baseline=True`, each term uses f(x + σεᵢ) - f(x) instead of f(x + σεᵢ); since E[∇ log p(ε)] = 0, this leaves the estimate unbiased while reducing its variance. Because only the noise density is differentiated, the gradient estimate is unbiased even when f itself is non-differentiable.
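As a quick sanity check on the estimator, the sketch below compares a handwritten Monte-Carlo estimate (with Normal noise and the baseline subtraction) against the exact gradient of a smooth test function; the test function, sample count, and sigma are chosen purely for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)  # smooth test function with known gradient 2 * x

def score_function_grad(key, x, num_samples=100_000, sigma=0.1):
    eps = jax.random.normal(key, (num_samples,) + x.shape)   # ε ~ N(0, I)
    # Baseline subtraction (the role of use_baseline): still unbiased, lower variance.
    values = jax.vmap(f)(x + sigma * eps) - f(x)             # f(x + σεᵢ) - f(x)
    score = -eps                                             # ∇ log p(εᵢ) for standard normal noise
    # ∇F(x) ≈ -(1/(N σ)) Σᵢ [f(x + σεᵢ) - f(x)] ∇ log p(εᵢ)
    return -jnp.mean(values[:, None] * score, axis=0) / sigma

key = jax.random.PRNGKey(0)
x = jnp.array([1.0, -2.0])
print(score_function_grad(key, x))  # ≈ [2., -4.] for large num_samples
print(jax.grad(f)(x))               # exact gradient: [2., -4.]
```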
## When to Use Perturbations

- **Discrete Operations**: Functions containing argmax, argmin, or discrete sampling (see the perturbed argmax sketch below)
- **Non-smooth Functions**: Functions with discontinuities or non-differentiable points
- **Combinatorial Optimization**: Problems requiring optimization over discrete choices
- **Reinforcement Learning**: Policy optimization with discrete action spaces
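For the discrete-operations case, a common pattern is to smooth a hard one-hot argmax so it can sit inside a differentiable pipeline. A minimal sketch, with shapes and hyperparameters chosen only for illustration:

```python
import jax
import jax.numpy as jnp
import optax

def one_hot_argmax(scores):
    # Hard, non-differentiable selection of the best-scoring option.
    return jax.nn.one_hot(jnp.argmax(scores), scores.shape[-1])

soft_argmax = optax.perturbations.make_perturbed_fun(
    one_hot_argmax, num_samples=1000, sigma=0.5
)

key = jax.random.PRNGKey(0)
scores = jnp.array([1.0, 2.0, 1.5])
probs = soft_argmax(key, scores)  # smoothed one-hot: a distribution over options
# With Gumbel noise, this smoothing corresponds (in expectation) to softmax(scores / sigma).
```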
## Import

```python
import optax.perturbations
# or
from optax.perturbations import make_perturbed_fun, Gumbel, Normal
```
## Types

```python { .api }
# Distribution interface
class NoiseDistribution:
    def sample(self, key, sample_shape=(), dtype=float) -> jax.Array:
        """Generate random samples."""

    def log_prob(self, inputs: jax.Array) -> jax.Array:
        """Compute log probability density."""
```
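Any object implementing this interface can be passed as `noise` to `make_perturbed_fun`. As an illustration, a hypothetical logistic-noise source (not part of the library) could be sketched as:

```python
import jax
import jax.numpy as jnp
import optax

class Logistic:
    """Hypothetical custom noise: standard logistic distribution (illustration only)."""

    def sample(self, key, sample_shape=(), dtype=float):
        return jax.random.logistic(key, shape=sample_shape, dtype=dtype)

    def log_prob(self, inputs):
        # Standard logistic log-density: -x - 2 * log(1 + exp(-x)) = -x - 2 * softplus(-x).
        return -inputs - 2.0 * jax.nn.softplus(-inputs)

smoothed_abs = optax.perturbations.make_perturbed_fun(
    lambda x: jnp.sum(jnp.abs(x)), num_samples=500, sigma=0.1, noise=Logistic()
)
```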