0
# Utilities and Tree Operations
1
2
Utility functions for parameter updates, tree operations, numerical stability, and working with JAX pytrees. These functions provide essential infrastructure for building and using optimizers effectively.
3
4
## Capabilities
5
6
### Parameter Updates
7
8
#### Core Update Functions
9
10
```python { .api }
11
def apply_updates(params, updates):
12
"""
13
Apply parameter updates to current parameters.
14
15
Args:
16
params: Current parameters (pytree)
17
updates: Parameter updates (pytree with same structure as params)
18
19
Returns:
20
Updated parameters (pytree)
21
"""
22
23
def incremental_update(new_tensors, old_tensors, step_size):
24
"""
25
Compute a Polyak-style moving average: `step_size * new_tensors + (1 - step_size) * old_tensors`.
26
27
Args:
28
new_tensors: New tensor values
29
old_tensors: Old tensor values
30
step_size: Weight given to `new_tensors` in the interpolation (typically in [0, 1])
31
32
Returns:
33
Incrementally updated tensors
34
"""
35
36
def periodic_update(new_tensors, old_tensors, steps, update_period):
37
"""
38
Return `new_tensors` when `steps` is a multiple of `update_period`, otherwise return `old_tensors`.
39
40
Args:
41
new_tensors: New tensor values
42
old_tensors: Old tensor values
43
steps: Current step count
44
update_period: Period for updates
45
46
Returns:
47
Conditionally updated tensors
48
"""
49
```
50
51
### Numerical Utilities
52
53
#### Safe Operations
54
55
```python { .api }
56
def safe_norm(x, min_norm=0.0, ord=None):
57
"""
58
Compute `max(norm(x), min_norm)` with gradients that remain finite when `x` is zero.
59
60
Args:
61
x: Input tensor
62
min_norm: Minimum norm value for stability (default: 0.0)
63
ord: Norm order (None, 1, 2, 'fro', etc.) (default: None for L2)
64
65
Returns:
66
Norm value with numerical stability
67
"""
68
69
def safe_root_mean_squares(x, min_rms=0.0):
70
"""
71
Compute `max(sqrt(mean(|x|**2)), min_rms)` with gradients that remain finite when `x` is zero.
72
73
Args:
74
x: Input tensor
75
min_rms: Minimum RMS value for stability (default: 0.0)
76
77
Returns:
78
RMS value with numerical stability
79
"""
80
81
def safe_increment(count):
82
"""
83
Increment a counter by one, saturating at the maximum value of its dtype instead of overflowing.
84
85
Args:
86
count: Current counter value
87
88
Returns:
89
Incremented counter value
90
"""
91
92
def safe_int32_increment(count):
93
"""
94
Increment an int32 counter by one; once the int32 maximum is reached the counter stays there instead of wrapping to the minimum.
95
96
Args:
97
count: Current int32 counter value
98
99
Returns:
100
Incremented int32 counter value
101
"""
102
```
103
104
### Linear Algebra
105
106
#### Matrix Operations
107
108
```python { .api }
109
def global_norm(updates):
110
"""
111
Compute the global L2 norm of a pytree: the square root of the sum of squared entries over all leaves.
112
113
Args:
114
updates: Parameter updates (pytree)
115
116
Returns:
117
Global norm scalar value
118
"""
119
120
def power_iteration(matrix, num_iters=10, error_tolerance=1e-6, precision=None):
121
"""
122
Compute dominant eigenvalue and eigenvector using power iteration.
123
124
Args:
125
matrix: Input matrix
126
num_iters: Maximum number of iterations (default: 10)
127
error_tolerance: Convergence tolerance (default: 1e-6)
128
precision: Numerical precision (default: None)
129
130
Returns:
131
Tuple of (eigenvalue, eigenvector)
132
"""
133
134
def matrix_inverse_pth_root(matrix, p, num_iters=15, ridge_epsilon=1e-6, error_tolerance=1e-6, precision=None):
135
"""
136
Compute matrix inverse p-th root using Newton's method.
137
138
Args:
139
matrix: Input positive definite matrix
140
p: Positive integer root order (e.g., 2 computes the inverse square root, `matrix^(-1/2)`)
141
num_iters: Maximum iterations (default: 15)
142
ridge_epsilon: Ridge regularization (default: 1e-6)
143
error_tolerance: Convergence tolerance (default: 1e-6)
144
precision: Numerical precision (default: None)
145
146
Returns:
147
Matrix inverse p-th root
148
"""
149
150
def nnls(a, b, max_iters=None, tol=1e-8):
151
"""
152
Non-negative least squares solver.
153
154
Args:
155
a: Coefficient matrix
156
b: Target vector
157
max_iters: Maximum iterations (default: None for auto)
158
tol: Convergence tolerance (default: 1e-8)
159
160
Returns:
161
Non-negative solution vector
162
"""
163
```
164
165
### Core Types and Base Functions
166
167
#### Base Transformations
168
169
```python { .api }
170
def identity():
171
"""
172
Identity transformation that passes gradients unchanged.
173
174
Returns:
175
GradientTransformation
176
"""
177
178
def set_to_zero():
179
"""
180
Transformation that sets all gradients to zero.
181
182
Returns:
183
GradientTransformation
184
"""
185
186
def stateless(f):
187
"""
188
Create stateless transformation from function.
189
190
Args:
191
f: Update function called as `f(updates, params)` that returns the new updates (`params` may be None)
192
193
Returns:
194
GradientTransformation
195
"""
196
197
def stateless_with_tree_map(f):
198
"""
199
Create stateless transformation with tree mapping.
200
201
Args:
202
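f: Function called as `f(update, param)` on each leaf of the updates tree, paired with the matching params leaf (`param` is None when no params are passed); it returns the new update leaf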
203
204
Returns:
205
GradientTransformation
206
"""
207
208
def with_extra_args_support(transformation):
209
"""
210
Add support for extra arguments to transformation.
211
212
Args:
213
transformation: Base transformation to extend
214
215
Returns:
216
GradientTransformationExtraArgs
217
"""
218
```
219
220
### Utility Functions
221
222
#### Gradient Processing
223
224
```python { .api }
225
def scale_gradient(inputs, scale):
226
"""
227
Scale gradients in the backward pass while acting as the identity in the forward pass.
228
229
Args:
230
inputs: Input values (forward pass is identity)
231
scale: Scale factor for gradients in backward pass
232
233
Returns:
234
Inputs (unchanged in forward pass)
235
"""
236
237
def value_and_grad_from_state(fun, argnums=0, has_aux=False):
238
"""
239
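Like `jax.value_and_grad`, but reuse the value and gradient already stored in the optimizer state (e.g., computed during a line search) when available instead of recomputing them.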
240
241
Args:
242
fun: Function to differentiate
243
argnums: Argument indices to differentiate (default: 0)
244
has_aux: Whether function returns auxiliary data (default: False)
245
246
Returns:
247
Function that returns (value, grad) tuple
248
"""
249
```
250
251
#### Random Utilities
252
253
```python { .api }
254
def multi_normal(loc, scale_tril, random_key):
255
"""
256
Sample from multivariate normal distribution.
257
258
Args:
259
loc: Mean vector
260
scale_tril: Lower triangular scale matrix
261
random_key: JAX random key
262
263
Returns:
264
Random sample from multivariate normal
265
"""
266
```
267
268
### Tree Operations
269
270
#### Basic Tree Arithmetic
271
272
```python { .api }
273
# Tree-level operations in optax.tree module
274
def add(tree1, tree2):
275
"""Element-wise addition of two pytrees."""
276
277
def sub(tree1, tree2):
278
"""Element-wise subtraction of two pytrees."""
279
280
def mul(tree1, tree2):
281
"""Element-wise multiplication of two pytrees."""
282
283
def div(tree1, tree2):
284
"""Element-wise division of two pytrees."""
285
286
def scale(tree, scalar):
287
"""Scale all elements in pytree by scalar."""
288
289
def norm(tree, ord=2):
290
"""Compute norm of pytree."""
291
292
def sum(tree):
293
"""Sum all elements in pytree."""
294
295
def max(tree):
296
"""Find maximum element in pytree."""
297
```
298
299
#### Tree Utilities
300
301
```python { .api }
302
def zeros_like(tree):
303
"""Create pytree of zeros with same structure."""
304
305
def ones_like(tree):
306
"""Create pytree of ones with same structure."""
307
308
def full_like(tree, fill_value):
309
"""Create pytree filled with specified value."""
310
```
311
312
### Assignment Module
313
314
#### Hungarian Algorithm
315
316
```python { .api }
317
def hungarian_algorithm(cost_matrix):
318
"""
319
Hungarian algorithm for solving assignment problems.
320
321
Args:
322
cost_matrix: 2D cost matrix for assignments
323
324
Returns:
325
Optimal assignment indices
326
"""
327
```
328
329
### Tree Utils Module
330
331
#### Parameter Tree Manipulation
332
333
```python { .api }
334
def tree_map_params(fn, tree):
335
"""
336
Map function over parameters in pytree.
337
338
Args:
339
fn: Function to apply to each parameter
340
tree: Parameter pytree
341
342
Returns:
343
Transformed pytree
344
"""
345
346
def tree_bias_correction(moment, decay, count):
347
"""
348
Correct the initialization bias of an exponential moving average: `moment / (1 - decay**count)`.
349
350
Args:
351
moment: Moment estimate
352
decay: Decay rate used for moment
353
count: Step count for bias correction
354
355
Returns:
356
Bias-corrected moment
357
"""
358
```
359
360
#### Moment Updates
361
362
```python { .api }
363
def tree_update_moment(updates, moments, decay, order):
364
"""
365
Update the exponential moving average of the `order`-th moment: `decay * moments + (1 - decay) * updates**order`.
366
367
Args:
368
updates: Current gradient updates
369
moments: Previous moment estimates
370
decay: Exponential decay rate
371
order: Moment order (1 for mean, 2 for variance)
372
373
Returns:
374
Updated moment estimates
375
"""
376
377
def tree_update_moment_per_elem_norm(updates, moments, decay, order):
378
"""
379
Update the exponential moving average of the `order`-th moment of the element-wise norm, i.e. of `|updates|**order`.
380
381
Args:
382
updates: Current gradient updates
383
moments: Previous moment estimates
384
decay: Exponential decay rate
385
order: Moment order
386
387
Returns:
388
Updated moment estimates of the element-wise norm
389
"""
390
391
def tree_update_infinity_moment(updates, moments, decay):
392
"""
393
Update infinity-norm moments: the element-wise maximum of `|updates|` and `decay * moments`.
394
395
Args:
396
updates: Current gradient updates
397
moments: Previous infinity moments
398
decay: Exponential decay rate
399
400
Returns:
401
Updated infinity moments
402
"""
403
```
404
405
### Type Definitions
406
407
```python { .api }
408
# Type aliases
409
OptState = chex.ArrayTree # Optimizer state
410
Params = chex.ArrayTree # Model parameters
411
Updates = Params # Gradient updates
412
Schedule = Callable[[chex.Numeric], chex.Numeric] # Schedule function
413
ScalarOrSchedule = Union[float, jax.Array, Schedule] # Flexible numeric type
414
MaskOrFn = Union[Any, Callable[[Params], Any]] # Mask or masking function
415
416
# Function type definitions
417
TransformInitFn = Callable[[Params], OptState]
418
TransformUpdateFn = Callable[[Updates, OptState, Optional[Params]], Tuple[Updates, OptState]]
419
TransformUpdateExtraArgsFn = Callable[[Updates, OptState, Optional[Params], ...], Tuple[Updates, OptState]]
420
421
# Core classes
422
class GradientTransformation(NamedTuple):
423
"""Core gradient transformation with init and update functions."""
424
init: TransformInitFn
425
update: TransformUpdateFn
426
427
class GradientTransformationExtraArgs(NamedTuple):
428
"""Extended transformation supporting extra arguments."""
429
init: TransformInitFn
430
update: TransformUpdateExtraArgsFn
431
432
class EmptyState(NamedTuple):
433
"""Empty state for stateless transformations."""
434
pass
435
436
class FactoredState(NamedTuple):
437
"""State for factorized operations."""
438
count: chex.Array
439
v_row: chex.ArrayTree
440
v_col: chex.ArrayTree
441
```
442
443
## Usage Examples
444
445
### Basic Parameter Updates
446
447
```python
448
import jax.numpy as jnp
449
import optax
450
451
# Parameters and updates
452
params = {'w': jnp.ones((5, 3)), 'b': jnp.zeros((3,))}
453
updates = {'w': jnp.ones((5, 3)) * 0.01, 'b': jnp.ones((3,)) * 0.001}
454
455
# Apply updates
456
new_params = optax.apply_updates(params, updates)
457
458
# Compute global norm
459
grad_norm = optax.global_norm(updates)
460
print(f"Global gradient norm: {grad_norm}")
461
```
462
463
### Numerical Stability
464
465
```python
466
# Safe operations for numerical stability
467
x = jnp.array([1e-8, 1e-6, 1.0, 1e6])
468
469
safe_norm_val = optax.safe_norm(x, min_norm=1e-8)
470
safe_rms_val = optax.safe_root_mean_squares(x, min_rms=1e-8)
471
472
# Safe counting
473
step_count = jnp.array(2147483647, dtype=jnp.int32) # Exactly int32 max; the increment saturates here instead of wrapping
474
next_count = optax.safe_int32_increment(step_count)
475
```
476
477
### Tree Operations
478
479
```python
480
# Tree arithmetic
481
tree1 = {'a': jnp.array([1, 2, 3]), 'b': jnp.array([4, 5])}
482
tree2 = {'a': jnp.array([6, 7, 8]), 'b': jnp.array([9, 10])}
483
484
# Element-wise operations
485
sum_tree = optax.tree.add(tree1, tree2)
486
scaled_tree = optax.tree.scale(tree1, 0.5)
487
tree_norm = optax.tree.norm(tree1)
488
489
# Tree utilities
490
zero_tree = optax.tree.zeros_like(tree1)
491
ones_tree = optax.tree.ones_like(tree1)
492
```
493
494
### Custom Transformations
495
496
```python
497
# Create custom stateless transformation
498
def my_scaling_fn(updates):
499
return jax.tree_map(lambda x: 0.01 * x, updates)
500
501
my_transform = optax.stateless(my_scaling_fn)
502
503
# Use with other transformations
504
optimizer = optax.chain(
505
optax.clip_by_global_norm(1.0),
506
my_transform,
507
optax.scale_by_adam()
508
)
509
```
510
511
### Advanced Usage
512
513
```python
514
# Matrix operations for second-order methods
515
def compute_preconditioner(gradients):
516
# Flatten gradients for matrix operations
517
flat_grads = jax.flatten_util.ravel_pytree(gradients)[0]
518
519
# Compute outer product approximation
520
outer_prod = jnp.outer(flat_grads, flat_grads)
521
522
# Compute matrix inverse square root
523
inv_sqrt = optax.matrix_inverse_pth_root(
524
outer_prod + 1e-6 * jnp.eye(len(flat_grads)),
525
p=2,
526
num_iters=10
527
)
528
529
return inv_sqrt
530
531
# Gradient scaling with state
532
def scale_with_state(inputs, state):
533
scale_factor = jnp.sqrt(state['step_count'])
534
return optax.scale_gradient(inputs, scale_factor)
535
```