# Gradient Transformations

Building blocks for creating custom optimizers including scaling, clipping, noise addition, and momentum accumulation. These transformations can be combined using `chain()` to build custom optimization strategies with fine-grained control over gradient processing.

## Capabilities

### Chaining Transformations

Combine multiple gradient transformations into a single optimizer.

```python { .api }
def chain(*args):
    """
    Chain multiple gradient transformations.

    Args:
        *args: Variable number of GradientTransformation objects

    Returns:
        GradientTransformationExtraArgs: Combined transformation
    """

def named_chain(**transformations):
    """
    Chain transformations with names for easier debugging.

    Args:
        **transformations: Named GradientTransformation objects

    Returns:
        GradientTransformation: Combined transformation with named states
    """
```
### Scaling Transformations

#### Basic Scaling

```python { .api }
def scale(step_size):
    """
    Scale updates by a constant factor.

    Args:
        step_size: Scaling factor (typically negative learning rate)

    Returns:
        GradientTransformation
    """

def scale_by_learning_rate(learning_rate):
    """
    Scale updates by learning rate (with negative sign).

    Args:
        learning_rate: Learning rate value or schedule

    Returns:
        GradientTransformation
    """

def scale_by_schedule(schedule):
    """
    Scale updates by a schedule function.

    Args:
        schedule: Schedule function taking step count and returning scale factor

    Returns:
        GradientTransformation
    """
```
#### Adaptive Scaling

```python { .api }
def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8, *, nesterov=False):
    """
    Scale updates using Adam-style adaptive scaling.

    Args:
        b1: Exponential decay rate for first moment estimates (default: 0.9)
        b2: Exponential decay rate for second moment estimates (default: 0.999)
        eps: Small constant for numerical stability (default: 1e-8)
        nesterov: Whether to use Nesterov momentum (default: False)

    Returns:
        GradientTransformation
    """

def scale_by_rms(decay=0.9, eps=1e-8):
    """
    Scale updates by root mean square of gradients.

    Args:
        decay: Decay rate for moving average (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """

def scale_by_stddev(decay=0.9, eps=1e-8):
    """
    Scale updates by standard deviation of gradients.

    Args:
        decay: Decay rate for moving average (default: 0.9)
        eps: Small constant for numerical stability (default: 1e-8)

    Returns:
        GradientTransformation
    """
```
### Momentum and Accumulation

```python { .api }
def trace(decay, nesterov=False, accumulator_dtype=None):
    """
    Add momentum/trace to gradient updates.

    Args:
        decay: Decay rate for momentum
        nesterov: Whether to use Nesterov momentum (default: False)
        accumulator_dtype: Data type for accumulator (default: None)

    Returns:
        GradientTransformation
    """

def ema(decay, debias=True, accumulator_dtype=None):
    """
    Exponential moving average of parameters.

    Args:
        decay: Decay rate for moving average
        debias: Whether to debias the moving average (default: True)
        accumulator_dtype: Data type for accumulator (default: None)

    Returns:
        GradientTransformation
    """
```
### Gradient Clipping

```python { .api }
def clip(max_delta):
    """
    Clip updates element-wise to maximum absolute value.

    Args:
        max_delta: Maximum absolute value for updates

    Returns:
        GradientTransformation
    """

def clip_by_global_norm(max_norm):
    """
    Clip updates by global norm.

    Args:
        max_norm: Maximum global norm for updates

    Returns:
        GradientTransformation
    """

def clip_by_block_rms(threshold):
    """
    Clip updates by block-wise RMS.

    Args:
        threshold: RMS threshold for clipping

    Returns:
        GradientTransformation
    """

def adaptive_grad_clip(clipping, eps=1e-3):
    """
    Adaptive gradient clipping.

    Args:
        clipping: Clipping threshold
        eps: Small constant for numerical stability (default: 1e-3)

    Returns:
        GradientTransformation
    """

def per_example_global_norm_clip(l2_norm_clip, single_batch_element=False):
    """
    Per-example gradient clipping for differential privacy.

    Args:
        l2_norm_clip: L2 norm clipping threshold
        single_batch_element: Whether input is a single batch element (default: False)

    Returns:
        GradientTransformation
    """
```
### Regularization

```python { .api }
def add_decayed_weights(weight_decay, mask=None):
    """
    Add L2 weight decay (weight regularization).

    Args:
        weight_decay: Weight decay coefficient
        mask: Optional mask for parameter selection

    Returns:
        GradientTransformation
    """

def add_noise(eta, gamma, seed):
    """
    Add gradient noise for improved generalization.

    Args:
        eta: Noise scaling parameter
        gamma: Annealing rate for noise
        seed: Random seed

    Returns:
        GradientTransformation
    """
```
### Conditioning and Normalization

```python { .api }
def centralize():
    """
    Centralize gradients by subtracting their mean.

    Returns:
        GradientTransformation
    """

def normalize_by_update_norm():
    """
    Normalize updates by their norm.

    Returns:
        GradientTransformation
    """

def scale_by_trust_ratio():
    """
    Scale updates by trust ratio (parameter norm / update norm).

    Returns:
        GradientTransformation
    """
```
### Conditional Operations

```python { .api }
def apply_if_finite(transformation):
    """
    Apply transformation only if gradients are finite.

    Args:
        transformation: Transformation to apply conditionally

    Returns:
        GradientTransformation
    """

def apply_every(k, transformation):
    """
    Apply transformation every k steps.

    Args:
        k: Step interval
        transformation: Transformation to apply periodically

    Returns:
        GradientTransformation
    """

def conditionally_transform(condition_fn, transformation):
    """
    Apply transformation based on condition function.

    Args:
        condition_fn: Function that returns boolean condition
        transformation: Transformation to apply conditionally

    Returns:
        GradientTransformation
    """
```
### Parameter Partitioning

```python { .api }
def partition(selector_fn, *transformations):
    """
    Apply different transformations to different parameter subsets.

    Args:
        selector_fn: Function to select parameter subsets
        *transformations: Transformations for each subset

    Returns:
        GradientTransformation
    """

def masked(mask_fn, transformation):
    """
    Apply transformation with parameter masking.

    Args:
        mask_fn: Function to generate parameter mask
        transformation: Transformation to apply with mask

    Returns:
        GradientTransformation
    """
```
### Parameter Constraints

```python { .api }
def keep_params_nonnegative():
    """
    Keep parameters non-negative by projecting to positive orthant.

    Returns:
        GradientTransformation
    """

def zero_nans():
    """
    Set NaN gradients to zero.

    Returns:
        GradientTransformation
    """
```
### Multi-Step Accumulation

```python { .api }
class MultiSteps:
    """Multi-step gradient accumulation."""

    def __init__(self, every_k_schedule, use_grad_mean=True):
        """
        Initialize multi-step accumulation.

        Args:
            every_k_schedule: Schedule for accumulation steps
            use_grad_mean: Whether to use gradient mean instead of sum (default: True)
        """

def skip_not_finite(updates, state, params=None):
    """
    Skip updates that are not finite.

    Args:
        updates: Gradient updates
        state: Optimizer state
        params: Optional parameters

    Returns:
        Tuple of (updates, state)
    """

def skip_large_updates(updates, state, max_norm):
    """
    Skip updates with norm larger than threshold.

    Args:
        updates: Gradient updates
        state: Optimizer state
        max_norm: Maximum allowed update norm

    Returns:
        Tuple of (updates, state)
    """
```
## Usage Examples

### Custom Optimizer with Chaining

```python
import jax.numpy as jnp
import optax

# Create custom optimizer by chaining transformations
custom_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),                # Gradient clipping
    optax.add_decayed_weights(weight_decay=1e-4),  # Weight decay
    optax.scale_by_adam(b1=0.9, b2=0.999),         # Adam scaling
    optax.scale(-0.001)                            # Learning rate
)

# Initialize with parameters
params = {'w': jnp.ones((10, 5)), 'b': jnp.zeros((5,))}
opt_state = custom_optimizer.init(params)
```

### Conditional and Partitioned Updates

```python
# Apply different learning rates to different parameter groups
def is_bias(path, param):
    return 'bias' in path

bias_tx = optax.scale(-0.01)     # Higher learning rate for biases
weight_tx = optax.scale(-0.001)  # Lower learning rate for weights

partitioned_optimizer = optax.partition(is_bias, bias_tx, weight_tx)

# Apply transformation only every 5 steps
sparse_optimizer = optax.apply_every(5, optax.adam(0.001))
```

### Robust Training Setup

```python
# Robust optimizer with multiple safeguards
robust_optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),          # Prevent exploding gradients
    optax.apply_if_finite(                   # Skip non-finite updates
        optax.chain(
            optax.centralize(),              # Center gradients
            optax.scale_by_adam(),           # Adaptive scaling
            optax.add_decayed_weights(1e-4), # Weight regularization
        )
    ),
    optax.scale_by_schedule(                 # Learning rate schedule
        optax.cosine_decay_schedule(0.001, 1000)
    )
)
```