0
# Tree Operations
1
2
JAX provides utilities for working with PyTrees (nested Python data structures containing arrays) through `jax.tree`. PyTrees are fundamental to JAX's functional programming approach and enable elegant handling of complex nested data structures like neural network parameters.
3
4
## Core Imports
5
6
```python
7
import jax
import jax.numpy as jnp
import jax.tree as jtree
from jax.tree import map, flatten, unflatten, reduce
9
```
10
11
## What are PyTrees?
12
13
PyTrees are nested Python data structures where:
14
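- **Leaves** are arrays, scalars, or any other object that is not a registered container (note that `None` is treated as an empty node with no leaves, not as a leaf)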
15
- **Nodes** are containers like lists, tuples, dicts, or custom classes registered as PyTree nodes
16
- The tree structure is preserved while operations apply to leaves
17
18
Common PyTree examples:
19
```python
20
# Simple trees
tree1 = [1, 2, 3]  # List of scalars
tree2 = {'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])}  # Dict of arrays

# Nested trees (neural network parameters)
params = {
    'dense1': {'weight': jnp.zeros((784, 128)), 'bias': jnp.zeros(128)},
    'dense2': {'weight': jnp.zeros((128, 10)), 'bias': jnp.zeros(10)},
}

# Mixed structures: sub-trees can be reused as values of a larger tree
state = {
    'params': params,
    'batch_stats': {'mean': jnp.zeros(128), 'var': jnp.ones(128)},
    'step': 0,  # Scalar leaf
}
36
```
37
38
## Capabilities
39
40
### Tree Traversal and Transformation
41
42
Apply functions to all leaves while preserving tree structure.
43
44
```python { .api }
45
def map(f, tree, *rest, is_leaf=None) -> Any:
    """
    Apply function to all leaves of one or more trees.

    Args:
        f: Function to apply; it is called with one leaf from `tree`
            followed by the corresponding leaf from each tree in `rest`
        tree: Primary PyTree; determines the structure of the result
        rest: Additional PyTrees with same structure as `tree`
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf instead of being traversed

    Returns:
        PyTree with same structure as `tree`, f applied to all leaves
    """
58
59
def map_with_path(f, tree, *rest, is_leaf=None) -> Any:
    """
    Apply function to leaves with path information.

    Args:
        f: Function taking (path, *leaves) as arguments; `path` is a tuple
            of key entries (for example `DictKey(key='bias')` for a dict
            key), not a tuple of plain strings
        tree: Primary PyTree
        rest: Additional PyTrees with same structure
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        PyTree with f applied to leaves, receiving path info
    """
72
73
def reduce(function, tree, initializer=None, is_leaf=None) -> Any:
    """
    Reduce tree to single value by applying function to all leaves.

    Args:
        function: Binary function `(accumulator, leaf) -> accumulator`
        tree: PyTree to reduce
        initializer: Optional initial value for reduction. When it is
            omitted the reduction starts from the first leaf. NOTE: in JAX
            itself the default is an internal "unspecified" sentinel rather
            than None, so pass a value explicitly whenever one is needed
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        Single value from reducing all leaves
    """
86
87
def all(tree, *, is_leaf=None) -> bool:
    """
    Return True if all leaves are truthy.

    Each leaf is evaluated for truthiness on its own, so array leaves must
    hold a single element; reduce larger arrays (e.g. with `jnp.all`) first.

    Args:
        tree: PyTree to check
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        Boolean indicating if all leaves are truthy
    """
97
```
98
99
Usage examples:
100
```python
101
# Apply function to all arrays in parameter tree
102
def init_weights(params):
    """Return a new tree with every leaf of `params` multiplied by 0.01."""
    return jtree.map(lambda x: x * 0.01, params)
104
105
# Element-wise operations on multiple trees
106
def add_trees(tree1, tree2):
    """Add two PyTrees of identical structure leaf by leaf."""
    return jtree.map(lambda x, y: x + y, tree1, tree2)
108
109
# Compute total number of parameters
110
def count_params(params):
    """
    Count the total number of elements across all leaves of `params`.

    `jnp.size` is used instead of the `.size` attribute so that plain Python
    scalar leaves (such as a step counter) count as one element instead of
    raising AttributeError.
    """
    return jtree.reduce(lambda count, x: count + jnp.size(x), params, initializer=0)
112
113
# Check if all gradients are finite
114
def all_finite(grads):
    """
    Return True when every element of every leaf in `grads` is finite.

    `jnp.isfinite` yields an array shaped like its input, and `jtree.all`
    evaluates the truthiness of each leaf, which fails for arrays with more
    than one element. Each leaf is therefore collapsed to a scalar with
    `jnp.all` first.
    """
    return jtree.all(jtree.map(lambda g: jnp.all(jnp.isfinite(g)), grads))
116
117
# Apply different functions based on path
118
def scale_by_path(path, param):
    """
    Scale a leaf according to where it sits in the tree.

    Args:
        path: Key path supplied by `map_with_path`: a tuple of key entries
            (e.g. `DictKey(key='bias')`), not plain strings
        param: Leaf value found at `path`

    Returns:
        `param * 0.1` if any dict key on the path equals 'bias',
        otherwise `param * 1.0`
    """
    # Dict key entries expose the raw key as `.key`; testing `'bias' in path`
    # directly never matches because the entries are not strings.
    is_bias = any(getattr(entry, 'key', None) == 'bias' for entry in path)
    if is_bias:
        return param * 0.1  # Smaller learning rate for biases
    return param * 1.0
123
124
# NOTE: `gradients` is not defined in this snippet; it is assumed to be a
# PyTree of gradient arrays produced elsewhere.
scaled_grads = jtree.map_with_path(scale_by_path, gradients)
125
```
126
127
### Tree Structure Operations
128
129
Flatten trees into lists and reconstruct them, useful for interfacing with optimizers and other libraries.
130
131
```python { .api }
132
def flatten(tree, is_leaf=None) -> tuple[list, Any]:
    """
    Flatten PyTree into list of leaves and tree definition.

    Leaves are listed in traversal order; for dicts that is sorted key
    order, not insertion order.

    Args:
        tree: PyTree to flatten
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        Tuple of (leaves_list, tree_definition)
    """
143
144
def unflatten(treedef, leaves) -> Any:
    """
    Reconstruct PyTree from tree definition and leaves.

    Args:
        treedef: Tree definition from flatten()
        leaves: Leaf values, in the same order in which flatten() produced
            them for this tree definition

    Returns:
        Reconstructed PyTree with original structure
    """
155
156
def flatten_with_path(tree, is_leaf=None) -> tuple[list, Any]:
    """
    Flatten PyTree with path information for each leaf.

    Args:
        tree: PyTree to flatten
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        Tuple of (path_leaf_pairs, tree_definition), where
        path_leaf_pairs is a list of (path, leaf) tuples
    """
167
168
def leaves(tree, is_leaf=None) -> list:
    """
    Get list of all leaves in PyTree.

    Args:
        tree: PyTree to extract leaves from
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        List containing all leaf values, in the same order as flatten()
    """
179
180
def leaves_with_path(tree, is_leaf=None) -> list:
    """
    Get list of (path, leaf) pairs.

    Args:
        tree: PyTree to extract leaves from
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        List of (path, leaf) tuples; each path is a tuple of key entries
        (e.g. `DictKey(key='w')`) locating the leaf inside the tree
    """
191
192
def structure(tree, is_leaf=None) -> Any:
    """
    Get tree structure (definition) without leaf values.

    Args:
        tree: PyTree to get structure from
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        Tree definition describing structure (the same kind of object that
        flatten() returns as its second item)
    """
203
```
204
205
Usage examples:
206
```python
207
# Flatten for use with scipy optimizers
params = {'w': jnp.array([1, 2]), 'b': jnp.array([3])}
flat_params, tree_def = jtree.flatten(params)
# Dict leaves are emitted in sorted key order, so 'b' comes before 'w'
print(flat_params)  # [Array([3], dtype=int32), Array([1, 2], dtype=int32)]

# Reconstruct after optimization; new leaves must follow the same
# ('b', 'w') order as the flattened list
new_flat_params = [jnp.array([6]), jnp.array([4, 5])]
new_params = jtree.unflatten(tree_def, new_flat_params)
print(new_params)  # {'b': Array([6], dtype=int32), 'w': Array([4, 5], dtype=int32)}
216
217
# Get all parameter arrays
all_arrays = jtree.leaves(params)

# Inspect structure with paths; each path is a tuple of key entries and the
# pairs follow sorted dict-key order
path_leaf_pairs = jtree.leaves_with_path(params)
print(path_leaf_pairs)
# [((DictKey(key='b'),), Array([3], dtype=int32)),
#  ((DictKey(key='w'),), Array([1, 2], dtype=int32))]

# Get structure for later use
structure_only = jtree.structure(params)
226
```
227
228
### Tree Transformation and Manipulation
229
230
Advanced operations for tree manipulation and structural transformations.
231
232
```python { .api }
233
def transpose(outer_treedef, inner_treedef, pytree_to_transpose) -> Any:
    """
    Transpose nested PyTree structure.

    The input must be structured as "outer tree whose every leaf position
    holds an inner tree"; the result swaps the two levels.

    Args:
        outer_treedef: Tree definition of the input's outer structure
        inner_treedef: Tree definition of the input's inner structure
        pytree_to_transpose: PyTree to transpose

    Returns:
        PyTree with transposed nested structure (inner structure on the
        outside, outer structure on the inside)
    """
245
```
246
247
Usage example:
248
```python
249
# Transpose structure: list of dicts -> dict of lists
list_of_dicts = [
    {'a': 1, 'b': 2},
    {'a': 3, 'b': 4},
    {'a': 5, 'b': 6},
]

# Get structure definitions. Each level is described on its own, using
# placeholder leaves (None would not work: it is an empty node, not a leaf).
outer_structure = jtree.structure([0] * len(list_of_dicts))  # List structure
inner_structure = jtree.structure({'a': 0, 'b': 0})  # Dict structure

# Transpose to dict of lists; the outer treedef is passed first
dict_of_lists = jtree.transpose(outer_structure, inner_structure, list_of_dicts)
print(dict_of_lists)  # {'a': [1, 3, 5], 'b': [2, 4, 6]}
263
```
264
265
### Broadcasting and Advanced Operations
266
267
```python { .api }
268
def broadcast(prefix_tree, full_tree, is_leaf=None) -> Any:
    """
    Broadcast a tree prefix into the full structure of a given tree.

    Example: broadcasting the prefix `(1, 2)` into the full tree
    `(0, {'a': 0, 'b': 0})` gives `(1, {'a': 2, 'b': 2})`.

    Args:
        prefix_tree: PyTree whose structure is a prefix of `full_tree`
        full_tree: PyTree providing the structure to broadcast into
        is_leaf: Optional function to determine what counts as leaf

    Returns:
        PyTree with the structure of `full_tree`, in which every leaf of
        `prefix_tree` is repeated across the corresponding subtree
    """
280
```
281
282
## Custom PyTree Types
283
284
Register custom classes as PyTree nodes:
285
286
```python
287
import jax
288
289
# Register custom class as PyTree node
290
class MyContainer:
    """Minimal container holding a dict of named values."""

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return f"MyContainer({self.data})"


def container_flatten(container):
    """
    Split a MyContainer into (children, aux_data).

    Children are the dict values (themselves PyTrees); aux_data is the tuple
    of keys. Both are built from the same key sequence so they stay aligned.
    """
    keys = tuple(container.data.keys())
    children = tuple(container.data[k] for k in keys)
    return (children, keys)


def container_unflatten(aux_data, children):
    """Rebuild a MyContainer by pairing the stored keys with `children`."""
    return MyContainer(dict(zip(aux_data, children)))


# Register the PyTree node
jax.tree_util.register_pytree_node(
    MyContainer,
    container_flatten,
    container_unflatten,
)

# Now MyContainer works with tree operations
container = MyContainer({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
doubled = jtree.map(lambda x: x * 2, container)
print(doubled)  # MyContainer({'x': Array([2, 4]), 'y': Array([6, 8])})
316
```
317
318
## Common Usage Patterns
319
320
### Neural Network Parameter Management
321
322
```python
323
# Initialize network parameters as PyTree
324
def init_mlp_params(layer_sizes, key):
    """
    Build MLP parameters as a PyTree keyed by layer name.

    Args:
        layer_sizes: Sequence of layer widths; consecutive pairs define the
            (in_size, out_size) of each layer
        key: JAX PRNG key used to draw the weights

    Returns:
        Dict mapping 'layer_{i}' to {'weights': ..., 'biases': ...}, with
        weights drawn from a normal distribution scaled by 0.01 and zero
        biases. Empty when fewer than two sizes are given.
    """
    num_layers = len(layer_sizes) - 1
    if num_layers < 1:
        # No (in, out) pair exists, so there is nothing to initialise
        return {}

    params = {}
    keys = jax.random.split(key, num_layers)

    for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        # Biases are zero-initialised, so only the weight key is consumed
        w_key, _ = jax.random.split(keys[i])
        params[f'layer_{i}'] = {
            'weights': jax.random.normal(w_key, (in_size, out_size)) * 0.01,
            'biases': jnp.zeros(out_size),
        }
    return params
335
336
# Apply gradients using tree operations
337
def update_params(params, grads, learning_rate):
    """Return `params - learning_rate * grads`, computed leaf by leaf."""
    return jtree.map(lambda p, g: p - learning_rate * g, params, grads)
339
340
# Compute parameter statistics
341
def param_stats(params):
    """
    Summarise a parameter PyTree.

    Returns:
        Dict with 'total_params' (sum of leaf sizes) and 'norm' (square root
        of the sum of squared elements over all leaves)
    """
    flat_params = jtree.leaves(params)
    total_params = sum(p.size for p in flat_params)
    squared_sum = sum(jnp.sum(p**2) for p in flat_params)
    return {'total_params': total_params, 'norm': jnp.sqrt(squared_sum)}
346
```
347
348
### Optimizer State Management
349
350
```python
351
# Adam optimizer state as PyTree
352
def init_adam_state(params):
    """
    Create the initial Adam optimizer state for `params`.

    Returns:
        Dict with zero-filled moment trees 'm' and 'v', each shaped like
        `params`, and the step counter 'step' set to 0
    """
    return {
        'm': jtree.map(jnp.zeros_like, params),  # First moment
        'v': jtree.map(jnp.zeros_like, params),  # Second moment
        'step': 0,
    }
358
359
def adam_update(params, grads, state, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """
    Perform one Adam step on a parameter PyTree.

    Args:
        params: Current parameters
        grads: Gradients with the same structure as `params`
        state: Dict with moment trees 'm', 'v' and the integer 'step'
        learning_rate: Step size
        beta1: Decay rate of the first-moment estimate
        beta2: Decay rate of the second-moment estimate
        eps: Constant added to the denominator of the update

    Returns:
        Tuple of (new_params, new_state)
    """
    step = state['step'] + 1

    # Update biased moments
    m = jtree.map(lambda m_prev, g: beta1 * m_prev + (1 - beta1) * g, state['m'], grads)
    v = jtree.map(lambda v_prev, g: beta2 * v_prev + (1 - beta2) * g**2, state['v'], grads)

    # Bias correction; the correction factors depend only on the step, so
    # they are computed once instead of once per leaf
    m_correction = 1 - beta1**step
    v_correction = 1 - beta2**step
    m_hat = jtree.map(lambda m_val: m_val / m_correction, m)
    v_hat = jtree.map(lambda v_val: v_val / v_correction, v)

    # Parameter update
    new_params = jtree.map(
        lambda p, m_val, v_val: p - learning_rate * m_val / (jnp.sqrt(v_val) + eps),
        params, m_hat, v_hat,
    )

    new_state = {'m': m, 'v': v, 'step': step}
    return new_params, new_state
378
```
379
380
### Batch Processing
381
382
```python
383
# Process batch of PyTrees
384
def process_batch(batch_trees):
    """
    Convert a list of identically structured PyTrees into one PyTree whose
    leaves are the per-example leaves stacked along a new leading axis.
    """
    # Unpacking hands every tree to map, so the lambda receives one leaf per
    # example at each position
    return jtree.map(lambda *arrays: jnp.stack(arrays), *batch_trees)
388
389
# Example: batch of neural network inputs
390
batch_inputs = [
    {'image': jnp.ones((28, 28)), 'label': 5},
    {'image': jnp.zeros((28, 28)), 'label': 3},
    {'image': jnp.ones((28, 28)) * 0.5, 'label': 1},
]

# Every leaf gains a leading batch axis of length len(batch_inputs)
batched = process_batch(batch_inputs)
print(batched['image'].shape)  # (3, 28, 28)
print(batched['label'].shape)  # (3,)
399
```