# Perturbation-Based Optimization

Utilities for perturbation-based optimization, which make non-differentiable functions amenable to gradient-based optimization. This module provides techniques for creating differentiable approximations of functions via stochastic smoothing: the function is evaluated under random noise perturbations and averaged.

## Capabilities

### Perturbed Function Creation

Creates differentiable approximations of potentially non-differentiable functions using stochastic perturbations.

```python { .api }
def make_perturbed_fun(
    fun,
    num_samples=1000,
    sigma=0.1,
    noise=Gumbel(),
    use_baseline=True
):
    """
    Creates a differentiable approximation of a function using stochastic perturbations.

    Transforms a potentially non-differentiable function into a smoothed, differentiable
    version by adding noise and averaging over multiple samples. Uses the score function
    estimator (REINFORCE) to provide unbiased Monte-Carlo estimates of derivatives.

    Args:
        fun: The function to transform (pytree → pytree with JAX array leaves)
        num_samples: Number of perturbed outputs to average over (default: 1000)
        sigma: Scale of the random perturbation (default: 0.1)
        noise: Distribution object with sample and log_prob methods (default: Gumbel())
        use_baseline: Whether to use the unperturbed function value as a baseline
            for variance reduction (default: True)

    Returns:
        Callable: New function with signature (PRNGKey, ArrayTree) → ArrayTree
    """
```

### Noise Distributions

#### Gumbel Distribution

The standard Gumbel distribution, the usual default in perturbation-based optimization: perturbing an argmax with Gumbel noise recovers the Gumbel-max trick, so the smoothed operator has a closed softmax form.

```python { .api }
class Gumbel:
    """Gumbel distribution for perturbation-based optimization."""

    def sample(self, key, sample_shape=(), dtype=float):
        """
        Generate random samples from the Gumbel distribution.

        Args:
            key: PRNG key for random sampling
            sample_shape: Shape of samples to generate (default: ())
            dtype: Data type for samples (default: float)

        Returns:
            jax.Array: Gumbel-distributed random values
        """

    def log_prob(self, inputs):
        """
        Compute the log probability density of inputs.

        Args:
            inputs: JAX array for which to compute log probabilities

        Returns:
            jax.Array: Log probabilities, computed as -inputs - exp(-inputs)
        """
```

#### Normal Distribution

Standard normal distribution as an alternative noise source for perturbations.

```python { .api }
class Normal:
    """Normal (Gaussian) distribution for perturbation-based optimization."""

    def sample(self, key, sample_shape=(), dtype=float):
        """
        Generate random samples from the standard normal distribution.

        Args:
            key: PRNG key for random sampling
            sample_shape: Shape of samples to generate (default: ())
            dtype: Data type for samples (default: float)

        Returns:
            jax.Array: Normally-distributed random values (mean=0, std=1)
        """

    def log_prob(self, inputs):
        """
        Compute the log probability density of inputs.

        Args:
            inputs: JAX array for which to compute log probabilities

        Returns:
            jax.Array: Log probabilities, computed as -0.5 * inputs²
            (the standard normal log-density up to an additive constant)
        """
```
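
Both classes expose the same two-method interface, so they are interchangeable as the `noise` argument. A quick demonstration (a minimal sketch, assuming the classes documented above):

```python
import jax
import optax

key = jax.random.PRNGKey(0)
gumbel = optax.perturbations.Gumbel()
normal = optax.perturbations.Normal()

# Draw a small batch of samples from each distribution.
g = gumbel.sample(key, sample_shape=(4,))
n = normal.sample(key, sample_shape=(4,))

# Evaluate the log-densities at the sampled points.
print(gumbel.log_prob(g))  # -g - exp(-g)
print(normal.log_prob(n))  # -0.5 * n**2
```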

## Usage Examples

### Basic Usage

```python
import jax
import jax.numpy as jnp
import optax

# Example: smoothing a ReLU-style function (non-differentiable at 0)
def non_differentiable_fn(x):
    return jnp.sum(jnp.maximum(x, 0.0))  # ReLU activation

# Create the perturbed version
key = jax.random.PRNGKey(42)
perturbed_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=1000,
    sigma=0.1,
    noise=optax.perturbations.Gumbel(),
)

# Now we can compute gradients
x = jnp.array([-1.0, 0.5, 2.0])
gradient = jax.grad(perturbed_fn, argnums=1)(key, x)
print(f"Gradient: {gradient}")
```

### Using Different Noise Distributions

```python
# Using Gumbel noise (the default)
gumbel_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Gumbel(),
)

# Using Normal noise
normal_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    noise=optax.perturbations.Normal(),
)

# Compare gradient estimates from the two noise distributions
key1, key2 = jax.random.split(key)
grad_gumbel = jax.grad(gumbel_fn, argnums=1)(key1, x)
grad_normal = jax.grad(normal_fn, argnums=1)(key2, x)
```

### Tuning Hyperparameters

```python
# Adjust the perturbation scale and sample count
fine_tuned_fn = optax.perturbations.make_perturbed_fun(
    fun=non_differentiable_fn,
    num_samples=5000,   # More samples: lower-variance gradient estimates
    sigma=0.05,         # Smaller perturbations: closer to the original function
    use_baseline=True,  # Subtract the unperturbed value for variance reduction
)
```
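
Note the trade-off these parameters control: a larger `sigma` gives a smoother but more heavily biased approximation of the original function, while a smaller `sigma` tracks it more closely at the cost of noisier gradient estimates; increasing `num_samples` reduces that noise at a proportional compute cost.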

### Real-World Application: Optimizing Discrete Choices

```python
def discrete_objective(weights):
    """Example objective with a discrete, non-differentiable step."""
    # Simulate a discrete decision-making process: score each option,
    # then commit to the single best one
    scores = weights * jnp.array([1.0, 2.0, 3.0])
    best_choice = jnp.argmax(scores)  # Non-differentiable
    return -scores[best_choice]  # Negative because we want to maximize

# Make it differentiable
differentiable_objective = optax.perturbations.make_perturbed_fun(
    fun=discrete_objective,
    num_samples=2000,
    sigma=0.2,
)

# Now we can use gradient-based optimization
def optimize_discrete_choice():
    weights = jnp.array([0.1, 0.1, 0.1])
    optimizer = optax.adam(0.01)
    opt_state = optimizer.init(weights)
    key = jax.random.PRNGKey(0)

    for step in range(100):
        key, subkey = jax.random.split(key)
        loss_val, grads = jax.value_and_grad(differentiable_objective, argnums=1)(
            subkey, weights
        )
        updates, opt_state = optimizer.update(grads, opt_state, weights)
        weights = optax.apply_updates(weights, updates)

        if step % 20 == 0:
            print(f"Step {step}, Loss: {loss_val:.3f}")

    return weights

optimized_weights = optimize_discrete_choice()
```

## Mathematical Foundation

The perturbation method is based on the score function (REINFORCE) estimator.

For a function f(x) and noise distribution p(ε), the perturbed function is:

```
F(x) = E[f(x + σε)],  ε ~ p
```

Its gradient is estimated by Monte-Carlo sampling:

```
∇F(x) ≈ -(1/(N·σ)) Σᵢ f(x + σεᵢ) ∇log p(εᵢ)
```

With `use_baseline=True`, f(x + σεᵢ) is replaced by f(x + σεᵢ) − f(x); this leaves the estimate unbiased (since E[∇log p(ε)] = 0) while reducing its variance. The estimator requires no derivatives of f itself, so it provides an unbiased estimate of ∇F, the gradient of the smoothed function, even when f is non-differentiable.
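
To make the estimator concrete, here is a minimal hand-rolled sketch (not the library's internal implementation) for standard normal noise, where ∇log p(ε) = −ε, compared against the `make_perturbed_fun` gradient:

```python
import jax
import jax.numpy as jnp
import optax

def f(x):
    return jnp.sum(jnp.maximum(x, 0.0))  # non-differentiable at 0

sigma, num_samples = 0.1, 100_000
key = jax.random.PRNGKey(0)
x = jnp.array([-1.0, 0.5, 2.0])

# Score-function estimate with Normal noise: since ∇log p(ε) = -ε,
# ∇F(x) ≈ (1/(N·σ)) Σᵢ (f(x + σεᵢ) - f(x)) εᵢ   (baseline f(x) for variance reduction)
eps = jax.random.normal(key, (num_samples,) + x.shape)
values = jax.vmap(f)(x + sigma * eps)
grad_est = jnp.mean((values - f(x))[:, None] * eps, axis=0) / sigma

# The library estimator for comparison
perturbed_f = optax.perturbations.make_perturbed_fun(
    f, num_samples=num_samples, sigma=sigma, noise=optax.perturbations.Normal()
)
grad_lib = jax.grad(perturbed_f, argnums=1)(key, x)

print(grad_est)  # both ≈ Φ(x/σ) elementwise, i.e. ≈ [0., 1., 1.] here
print(grad_lib)
```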

## When to Use Perturbations

- **Discrete Operations**: Functions containing argmax, argmin, or discrete sampling (see the sketch after this list)
- **Non-smooth Functions**: Functions with discontinuities or non-differentiable points
- **Combinatorial Optimization**: Problems requiring optimization over discrete choices
- **Reinforcement Learning**: Policy optimization with discrete action spaces
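
For instance, perturbing a hard one-hot argmax yields a soft, differentiable relaxation; with Gumbel noise this is the classic Gumbel-max construction, so the averaged output approaches `softmax(x / sigma)`. A minimal sketch, assuming the API above:

```python
import jax
import jax.numpy as jnp
import optax

def hard_argmax(x):
    # One-hot indicator of the largest entry: piecewise constant, zero gradient
    return jax.nn.one_hot(jnp.argmax(x), x.shape[-1])

soft_argmax = optax.perturbations.make_perturbed_fun(
    hard_argmax, num_samples=10_000, sigma=0.5, noise=optax.perturbations.Gumbel()
)

key = jax.random.PRNGKey(0)
x = jnp.array([1.0, 2.0, 1.5])
print(soft_argmax(key, x))      # ≈ jax.nn.softmax(x / 0.5)
print(jax.nn.softmax(x / 0.5))
```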

## Import

```python
import optax.perturbations
# or
from optax.perturbations import make_perturbed_fun, Gumbel, Normal
```

## Types

```python { .api }
# Distribution interface: any object exposing these two methods can be
# passed as the `noise` argument of `make_perturbed_fun`.
class NoiseDistribution:
    def sample(self, key, sample_shape=(), dtype=float) -> jax.Array:
        """Generate random samples."""

    def log_prob(self, inputs: jax.Array) -> jax.Array:
        """Compute log probability density."""
```
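
Because the interface is structural, a custom noise source is just a class with these two methods. A minimal sketch, using a hypothetical `Logistic` class (not part of optax) with the standard logistic density:

```python
import jax
import jax.numpy as jnp

class Logistic:
    """Hypothetical standard-logistic noise matching the interface above."""

    def sample(self, key, sample_shape=(), dtype=float):
        # jax.random.logistic draws from the standard logistic distribution
        return jax.random.logistic(key, sample_shape, dtype)

    def log_prob(self, inputs):
        # log p(x) = -x - 2·log(1 + exp(-x)) for the standard logistic density
        return -inputs - 2.0 * jax.nn.softplus(-inputs)

# Usable anywhere a noise distribution is expected, e.g.
# optax.perturbations.make_perturbed_fun(fun, noise=Logistic())
```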