# Tree Manipulation (etree)

Universal tree manipulation utilities compatible with TensorFlow `nest`, JAX `tree_util`, DeepMind `tree`, `optree`, and pure Python data structures. They provide a unified API for working with nested data structures across different ML frameworks.

## Capabilities

### Core Tree Type

Type definitions for tree structures.

```python { .api }
Tree = Any  # Nested data structure (dict, list, tuple, or custom)
LeafFn = Callable[[Any], bool]  # Function to determine what constitutes a leaf
TreeDef = Any  # Tree structure definition from flatten operations
```
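The following pure-Python sketch makes the `LeafFn` contract concrete. A node for which `is_leaf` returns `True` is kept whole instead of being traversed. This is not etils code. `collect_leaves` is a hypothetical helper, and it treats only dicts, lists and tuples as containers. It also visits dict keys in sorted order, an assumption made here for deterministic output; real backends may order keys differently.

```python
from typing import Any, Callable, List, Optional

LeafFn = Callable[[Any], bool]


def collect_leaves(tree: Any, is_leaf: Optional[LeafFn] = None) -> List[Any]:
    """Return the leaves of `tree` in depth-first order (hypothetical helper)."""
    # A node accepted by `is_leaf` is kept whole instead of being traversed.
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        # Assumption of this sketch: dict keys are visited in sorted order.
        children = [tree[key] for key in sorted(tree)]
    elif isinstance(tree, (list, tuple)):
        children = list(tree)
    else:
        return [tree]  # Anything that is not a container is a leaf.
    leaves: List[Any] = []
    for child in children:
        leaves.extend(collect_leaves(child, is_leaf))
    return leaves


tree = {'point': (1, 2), 'ids': [3, 4]}
print(collect_leaves(tree))  # [3, 4, 1, 2]
print(collect_leaves(tree, is_leaf=lambda x: isinstance(x, tuple)))  # [3, 4, (1, 2)]
```

Treating tuples as leaves is a typical use of `is_leaf`: a coordinate pair then stays intact through a map instead of being split into two scalars.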

### Tree API Objects

The same `TreeAPI`, bound to different backend implementations.

```python { .api }
jax: TreeAPI  # JAX tree operations backend
optree: TreeAPI  # optree backend
tree: TreeAPI  # DeepMind tree backend
nest: TreeAPI  # TensorFlow nest backend
py: TreeAPI  # Pure Python backend (default)
```

### Core Tree Operations

The functions below are methods of `TreeAPI`, so each backend object (`etree.py`, `etree.jax`, `etree.nest`, ...) provides them. The examples in this document call them through `etree.py`.

```python { .api }
def map(
    map_fn: Callable[..., Any],
    *trees: Tree,
    is_leaf: Optional[LeafFn] = None,
) -> Tree:
    """
    Apply a function to every leaf of one or more trees.

    Args:
        map_fn: Function applied to each leaf, or to each group of
            corresponding leaves when several trees are given
        *trees: Input tree structures (supports multiple trees with the
            same structure)
        is_leaf: Predicate deciding which nodes are treated as leaves

    Returns:
        Tree with the same structure, with `map_fn` applied to every leaf
    """

def parallel_map(
    map_fn: Callable[..., Any],
    *trees: Tree,
    num_threads: Optional[int] = None,
    progress_bar: bool = False,
    is_leaf: Optional[LeafFn] = None,
) -> Tree:
    """
    Same as `map`, but applies `map_fn` to the leaves in parallel threads.

    Args:
        map_fn: Function applied to each leaf, or to each group of
            corresponding leaves when several trees are given
        *trees: Input tree structures (supports multiple trees with the
            same structure)
        num_threads: Number of parallel threads to use
        progress_bar: Whether to display a progress bar
        is_leaf: Predicate deciding which nodes are treated as leaves

    Returns:
        Tree with `map_fn` applied to all leaves
    """

def unzip(tree: Tree) -> Iterator[Tree]:
    """
    Unpack a tree whose leaves are iterables (e.g. arrays) into a sequence
    of trees, one per element, by iterating over all leaves in lockstep.

    Args:
        tree: Tree whose leaves are iterables of the same length

    Returns:
        Iterator of trees with the same structure as the input
    """

def stack(tree: Tree) -> Tree:
    """
    Stack a sequence of trees with the same structure into a single tree,
    stacking corresponding array leaves along a new leading axis.

    Args:
        tree: Sequence of trees whose leaves are arrays

    Returns:
        Single tree whose leaves are the stacked arrays
    """

def spec_like(
    tree: Tree,
    *,
    ignore_other: bool = True,
) -> Tree:
    """
    Replace each array leaf with a spec describing its shape and dtype,
    e.g. to inspect a large tree of model parameters.

    Args:
        tree: Input tree of arrays
        ignore_other: If True, non-array leaves are returned unchanged

    Returns:
        Tree with the same structure, containing array specs
    """

def copy(tree: Tree) -> Tree:
    """
    Create a deep copy of the tree structure.

    Args:
        tree: Input tree structure

    Returns:
        Deep copy of the tree
    """

# Backend-level functions, accessed through the `backend` attribute
# (e.g. `etree.py.backend.flatten`)
def flatten(tree: Tree, *, is_leaf: Optional[LeafFn] = None) -> tuple[list, TreeDef]:
    """
    Flatten a tree into its list of leaves and a structure definition.

    Args:
        tree: Input tree structure
        is_leaf: Predicate deciding which nodes are treated as leaves

    Returns:
        Tuple of (flat_sequence, tree_structure)
    """

def unflatten(structure: TreeDef, flat_sequence: list) -> Tree:
    """
    Reconstruct a tree from a structure definition and a list of leaves.

    Args:
        structure: Tree structure definition returned by `flatten()`
        flat_sequence: Leaf values, in the order produced by `flatten()`

    Returns:
        Reconstructed tree
    """

def assert_same_structure(tree0: Tree, tree1: Tree) -> None:
    """
    Assert that two trees have the same structure.

    Args:
        tree0: First tree
        tree1: Second tree

    Raises:
        ValueError: If the structures don't match
    """
```
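The other operations can be built on the `flatten`/`unflatten` pair. Split a tree into leaves plus a structure description, transform the leaves, then rebuild the tree. The pure-Python sketch below illustrates that round trip. The function names mirror the API above, but several details are assumptions of this sketch. These include the implementation itself, the treedef format (nested tuples with a `_LEAF` placeholder) and the sorted dict-key order. None of this describes how any backend actually encodes a `TreeDef`.

```python
from typing import Any, List, Tuple

_LEAF = object()  # Placeholder stored in the treedef wherever a leaf was.


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split `tree` into (leaves, treedef); dict keys are visited sorted."""
    if isinstance(tree, dict):
        keys = sorted(tree)
        parts = [flatten(tree[key]) for key in keys]
        node = (dict, tuple(keys))
    elif isinstance(tree, (list, tuple)):
        parts = [flatten(child) for child in tree]
        node = (tuple if isinstance(tree, tuple) else list, None)
    else:
        return [tree], _LEAF
    leaves = [leaf for sub_leaves, _ in parts for leaf in sub_leaves]
    return leaves, (node, tuple(sub_def for _, sub_def in parts))


def unflatten(treedef: Any, leaves: List[Any]) -> Any:
    """Rebuild a tree, consuming `leaves` in the order flatten() produced them."""
    remaining = iter(leaves)

    def build(definition: Any) -> Any:
        if definition is _LEAF:
            return next(remaining)
        (container, keys), children = definition
        values = [build(child) for child in children]
        if container is dict:
            return dict(zip(keys, values))
        return container(values)

    return build(treedef)


def assert_same_structure(tree0: Any, tree1: Any) -> None:
    """Two trees have the same structure exactly when their treedefs are equal."""
    if flatten(tree0)[1] != flatten(tree1)[1]:
        raise ValueError('The two trees do not have the same structure.')


params = {'w': [1.0, 2.0], 'b': (0.5,)}
leaves, treedef = flatten(params)
print(leaves)  # [0.5, 1.0, 2.0]
print(unflatten(treedef, [x * 10 for x in leaves]))  # {'b': (5.0,), 'w': [10.0, 20.0]}
assert_same_structure(params, {'w': [0, 0], 'b': (0,)})  # passes silently
```

Because the treedef records container types, a list and a tuple with the same elements count as different structures here.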

### Backend Modules

Access to underlying backend implementations.

```python { .api }
backend: ModuleType  # Backend implementations module
tree_utils: ModuleType  # Core tree utility functions module
```

## Usage Examples

### Basic Tree Operations

```python
from etils import etree

# Define a nested data structure
data = {
    'params': {
        'weights': [[1.0, 2.0], [3.0, 4.0]],
        'bias': [0.1, 0.2],
    },
    'config': {
        'learning_rate': 0.01,
        'batch_size': 32,
    },
}

# Apply function to all numeric values
doubled = etree.py.map(lambda x: x * 2 if isinstance(x, (int, float)) else x, data)
# Result: all numeric values doubled

# Deep copy the structure
data_copy = etree.py.copy(data)
```

### Working with Multiple Trees

```python
from etils import etree

# Multiple parameter sets with the same structure
tree1 = {'a': [1, 2], 'b': {'c': 3}}
tree2 = {'a': [4, 5], 'b': {'c': 6}}

# Add corresponding leaves of the two trees
combined = etree.py.map(lambda x, y: x + y, tree1, tree2)
# Result: {'a': [5, 7], 'b': {'c': 9}}
```
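Mapping over several trees only works when they share a structure. The minimal recursive sketch below walks the trees in lockstep and assumes dicts, lists and tuples are the only containers. `map_trees` is a hypothetical name. It raises `ValueError` on a mismatch, mirroring the `assert_same_structure` contract documented above rather than any specific backend's error messages.

```python
from typing import Any, Callable

_CONTAINERS = (dict, list, tuple)


def map_trees(fn: Callable[..., Any], *trees: Any) -> Any:
    """Apply `fn` to corresponding leaves of `trees`, walking them in lockstep."""
    first = trees[0]
    if isinstance(first, dict):
        if any(not isinstance(t, dict) or t.keys() != first.keys() for t in trees):
            raise ValueError('Trees have different dict keys.')
        return {key: map_trees(fn, *(t[key] for t in trees)) for key in first}
    if isinstance(first, (list, tuple)):
        if any(type(t) is not type(first) or len(t) != len(first) for t in trees):
            raise ValueError('Trees have different sequence types or lengths.')
        # zip(*trees) groups the i-th child of every tree together.
        mapped = [map_trees(fn, *children) for children in zip(*trees)]
        return tuple(mapped) if isinstance(first, tuple) else mapped
    if any(isinstance(t, _CONTAINERS) for t in trees):
        raise ValueError('A leaf in one tree is a container in another.')
    return fn(*trees)


tree1 = {'a': [1, 2], 'b': {'c': 3}}
tree2 = {'a': [4, 5], 'b': {'c': 6}}
print(map_trees(lambda x, y: x + y, tree1, tree2))  # {'a': [5, 7], 'b': {'c': 9}}
```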

### Framework Compatibility

```python
from etils import etree
import jax.numpy as jnp
import tensorflow as tf

# JAX compatibility
jax_tree = {'params': jnp.array([1, 2, 3])}
processed_jax = etree.jax.map(lambda x: x * 2, jax_tree)

# TensorFlow compatibility
tf_tree = {'weights': tf.constant([1.0, 2.0, 3.0])}
processed_tf = etree.nest.map(lambda x: x * 2, tf_tree)

# Pure Python (default)
py_tree = {'data': [1, 2, 3]}
processed_py = etree.py.map(lambda x: x * 2, py_tree)
```

### Advanced Tree Operations

```python
import numpy as np
from etils import etree

# Stack a list of examples whose leaves are arrays
examples = [
    {'features': np.array([1, 2]), 'label': np.array(0)},
    {'features': np.array([3, 4]), 'label': np.array(1)},
    {'features': np.array([5, 6]), 'label': np.array(0)},
]
batched = etree.py.stack(examples)
# Result: 'features' becomes an array of shape (3, 2), 'label' one of shape (3,)

# Unzip the batch back into one tree per example
unbatched = list(etree.py.unzip(batched))
# Result: 3 trees with the same structure as `batched`

# Create a spec structure
spec = etree.py.spec_like(batched)
# Result: same structure, with each array replaced by a spec of its shape and dtype
```
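`unzip` and `stack` are roughly inverse operations. The pure-Python sketch below illustrates that relationship and is not the library implementation. To stay dependency-free, Python lists stand in for array leaves, so only dicts are treated as containers. `map_dicts`, `leaves_of`, `unzip_sketch` and `stack_sketch` are hypothetical names.

```python
from typing import Any, Callable, Iterator, List


def map_dicts(fn: Callable[..., Any], *trees: Any) -> Any:
    """Leaf-wise map where only dicts are containers; lists act as array leaves."""
    if isinstance(trees[0], dict):
        return {key: map_dicts(fn, *(t[key] for t in trees)) for key in trees[0]}
    return fn(*trees)


def leaves_of(tree: Any) -> List[Any]:
    """Collect the (list) leaves of a tree of nested dicts."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in leaves_of(value)]
    return [tree]


def unzip_sketch(tree: Any) -> Iterator[Any]:
    """Yield one tree per index; like zip(), stop at the shortest leaf."""
    length = min(len(leaf) for leaf in leaves_of(tree))
    for i in range(length):
        # The lambda is used immediately, so capturing `i` here is safe.
        yield map_dicts(lambda leaf: leaf[i], tree)


def stack_sketch(trees: List[Any]) -> Any:
    """Inverse of unzip_sketch: gather corresponding leaves into one list each."""
    return map_dicts(lambda *values: list(values), *trees)


batch = {'features': [[1, 2], [3, 4], [5, 6]], 'label': [0, 1, 0]}
examples = list(unzip_sketch(batch))
print(examples[0])  # {'features': [1, 2], 'label': 0}
print(stack_sketch(examples) == batch)  # True
```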

### Parallel Processing

```python
import numpy as np
from etils import etree

# Large data structure with expensive per-leaf operations
large_data = {
    'layer1': {'weights': np.random.rand(1000, 1000)},
    'layer2': {'weights': np.random.rand(1000, 1000)},
    'layer3': {'weights': np.random.rand(1000, 1000)},
}

# Expensive function: the left singular vectors U of an SVD
def expensive_op(x):
    return np.linalg.svd(x)[0]

# Leaves are processed by parallel threads, which pays off when the function
# releases the GIL (as many NumPy routines do) or waits on I/O
result = etree.py.parallel_map(expensive_op, large_data)
```
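One plausible way to implement a parallel map is to submit every leaf to a `concurrent.futures.ThreadPoolExecutor` and then rebuild the tree from the results. The sketch below assumes that design and handles a single tree of dicts, lists and tuples. It is not the etils implementation, and it omits `is_leaf` and the progress bar. `parallel_map_sketch` and `slow_square` are hypothetical names.

```python
import concurrent.futures
import time
from typing import Any, Callable, Optional


def parallel_map_sketch(
    fn: Callable[[Any], Any], tree: Any, num_threads: Optional[int] = None
) -> Any:
    """Run `fn` on every leaf in a thread pool and rebuild the tree."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:

        def submit(node: Any) -> Any:
            # Same structure, but each leaf is replaced by a Future.
            if isinstance(node, dict):
                return {key: submit(value) for key, value in node.items()}
            if isinstance(node, (list, tuple)):
                children = [submit(child) for child in node]
                return tuple(children) if isinstance(node, tuple) else children
            return pool.submit(fn, node)

        def collect(node: Any) -> Any:
            # Swap each Future for its result; worker exceptions re-raise here.
            if isinstance(node, dict):
                return {key: collect(value) for key, value in node.items()}
            if isinstance(node, (list, tuple)):
                children = [collect(child) for child in node]
                return tuple(children) if isinstance(node, tuple) else children
            return node.result()

        # All leaves are submitted before any result is awaited.
        return collect(submit(tree))


def slow_square(x: int) -> int:
    time.sleep(0.1)  # Stands in for I/O or a GIL-releasing NumPy call.
    return x * x


print(parallel_map_sketch(slow_square, {'a': [1, 2], 'b': (3,)}, num_threads=3))
# {'a': [1, 4], 'b': (9,)}
```

Submitting everything first and collecting afterwards is what lets the three sleeps overlap; calling `result()` inside `submit` would serialize the work.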