# Tree Utilities

Utilities for manipulating JAX pytrees: nested containers (dicts, lists, tuples, named tuples, ...) whose leaves are arrays, such as neural-network parameters, gradients and optimizer states. Most functions apply an operation leaf by leaf; functions that take several trees expect them to have the same structure.

## Capabilities

### Tree Manipulation

```python { .api }
def tree_cast(tree, dtype):
    """Cast all leaves of a tree to the given dtype."""

def tree_cast_like(tree, target_tree):
    """Cast each leaf of tree to the dtype of the matching leaf of target_tree."""

def tree_dtype(tree):
    """Return the dtype of the tree's leaves."""

def tree_get(tree, key, default=None, filtering=None):
    """Return the value stored under `key` (a dict key, NamedTuple field, ...).

    Returns `default` if `key` does not occur in the tree and raises KeyError
    if it occurs more than once.
    """

def tree_set(tree, filtering=None, /, **kwargs):
    """Return a copy of tree with the values under the given keys replaced,
    e.g. tree_set(state, count=0)."""

def tree_get_all_with_path(tree, key, filtering=None):
    """Return every (path, value) pair in tree whose key matches `key`."""
```

### Tree Arithmetic

```python { .api }
def tree_add(tree_a, tree_b):
    """Element-wise addition of two trees."""

def tree_sub(tree_a, tree_b):
    """Element-wise subtraction of two trees (tree_a - tree_b)."""

def tree_mul(tree_a, tree_b):
    """Element-wise multiplication of two trees."""

def tree_div(tree_a, tree_b):
    """Element-wise division of two trees (tree_a / tree_b)."""

def tree_add_scale(tree, scalar, scaled_tree):
    """Compute tree + scalar * scaled_tree."""

def tree_scale(scalar, tree):
    """Multiply every leaf of tree by scalar (note: the scalar comes first)."""
```

### Tree Reductions

```python { .api }
def tree_sum(tree):
    """Sum of all elements of all leaves."""

def tree_max(tree):
    """Largest element across all leaves."""

def tree_vdot(tree_a, tree_b):
    """Inner product of two trees: the sum over leaves of vdot(leaf_a, leaf_b)."""

def tree_batch_shape(tree, shape=()):
    """Add the leading batch dimensions `shape` to each leaf of tree."""
```

### Tree Creation

```python { .api }
def tree_zeros_like(tree):
    """Create a tree of zeros with the same structure and leaf shapes."""

def tree_ones_like(tree):
    """Create a tree of ones with the same structure and leaf shapes."""

def tree_full_like(tree, fill_value):
    """Create a tree with the same structure and leaf shapes, filled with fill_value."""

def tree_random_like(rng_key, target_tree, sampler=jax.random.normal, dtype=None):
    """Create a tree of random values with the same structure and leaf shapes
    as target_tree (note: the PRNG key comes first)."""
```

### Complex Number Support

```python { .api }
def tree_real(tree):
    """Real part of every leaf."""

def tree_conj(tree):
    """Complex conjugate of every leaf."""

def tree_where(condition, tree_a, tree_b):
    """Take values from tree_a where condition is true and from tree_b otherwise."""
```

## Usage Examples

```python
import jax
import jax.numpy as jnp
import optax

# Example tree structure (neural network parameters)
params = {
    'dense1': {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros(5)},
    'dense2': {'weights': jnp.ones((5, 1)), 'bias': jnp.zeros(1)},
}

# Tree arithmetic operations
grads = optax.tree_utils.tree_zeros_like(params)
scaled_grads = optax.tree_utils.tree_scale(0.01, grads)  # scalar comes first
updated_params = optax.tree_utils.tree_sub(params, scaled_grads)

# Tree reductions
total_params = optax.tree_utils.tree_sum(
    jax.tree_util.tree_map(lambda x: x.size, params)
)
param_norm = jnp.sqrt(optax.tree_utils.tree_vdot(params, params))

# Random initialization (the PRNG key comes first)
key = jax.random.PRNGKey(42)
random_params = optax.tree_utils.tree_random_like(key, params)

# Reading and replacing values by key name, e.g. in an optimizer state
opt_state = optax.adam(learning_rate=1e-3).init(params)
count = optax.tree_utils.tree_get(opt_state, 'count')
opt_state = optax.tree_utils.tree_set(opt_state, count=count + 1)
```

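## Import

```python
import optax  # functions are available as optax.tree_utils.tree_*
# or
from optax.tree_utils import tree_add, tree_scale, tree_zeros_like
```

Recent Optax releases also expose these functions without the `tree_` prefix in the `optax.tree` namespace (for example `optax.tree.add` and `optax.tree.scale`); check the version you are using.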