# Tree Manipulation (etree)

Universal tree manipulation utilities compatible with TensorFlow nest, JAX tree_utils, DeepMind tree, and pure Python data structures. Provides a unified API for working with nested data structures across different ML frameworks.

## Capabilities

### Core Tree Type

Type definitions for tree structures.

```python { .api }
Tree = Any  # Nested data structure (dict, list, tuple, or custom)
LeafFn = Callable[[Any], bool]  # Function to determine what constitutes a leaf
TreeDef = Any  # Tree structure definition from flatten operations
```

### Tree API Objects

Different backend implementations for tree operations.

```python { .api }
jax: TreeAPI  # JAX tree operations backend
optree: TreeAPI  # Optree backend
tree: TreeAPI  # DeepMind tree backend
nest: TreeAPI  # TensorFlow nest backend
py: TreeAPI  # Pure Python backend (default)
```

### Core Tree Operations

The `py` API object provides the primary tree manipulation functions; the other backend objects expose the same interface.

```python { .api }
def map(
    map_fn: Callable[..., Any],
    *trees: Tree,
    is_leaf: Optional[LeafFn] = None
) -> Tree:
    """
    Apply function to all leaf values in tree structures.

    Args:
        map_fn: Function to apply to each leaf or set of leaves
        *trees: Input tree structures (supports multiple trees)
        is_leaf: Function to determine what constitutes a leaf

    Returns:
        Tree with function applied to all leaves
    """

def parallel_map(
    map_fn: Callable[..., Any],
    *trees: Tree,
    num_threads: Optional[int] = None,
    progress_bar: bool = False,
    is_leaf: Optional[LeafFn] = None
) -> Tree:
    """
    Apply function to all leaf values in parallel.

    Args:
        map_fn: Function to apply to each leaf or set of leaves
        *trees: Input tree structures (supports multiple trees)
        num_threads: Number of parallel threads to use
        progress_bar: Whether to display a progress bar
        is_leaf: Function to determine what constitutes a leaf

    Returns:
        Tree with function applied to all leaves in parallel
    """

def unzip(tree: Tree) -> Tree:
    """
    Unzip a tree of tuples/lists into a tuple/list of trees.

    Args:
        tree: Tree containing tuples or lists

    Returns:
        Tuple/list of trees
    """

def stack(tree: Tree) -> Tree:
    """
    Stack multiple trees into a single tree.

    Args:
        tree: Tree containing stackable elements

    Returns:
        Stacked tree structure
    """

def spec_like(
    tree: Tree,
    *,
    ignore_other: bool = True
) -> Tree:
    """
    Create a spec-like structure matching the tree shape.

    Args:
        tree: Input tree structure
        ignore_other: Whether to ignore non-array types

    Returns:
        Spec structure matching input tree
    """

def copy(tree: Tree) -> Tree:
    """
    Create a deep copy of the tree structure.

    Args:
        tree: Input tree structure

    Returns:
        Deep copy of the tree
    """

# Backend-specific methods (available via backend attribute)
def flatten(tree: Tree, *, is_leaf: Optional[LeafFn] = None) -> tuple[list, TreeDef]:
    """
    Flatten a tree structure into a list of leaves and structure definition.

    Args:
        tree: Input tree structure
        is_leaf: Function to determine what constitutes a leaf

    Returns:
        Tuple of (flat_sequence, tree_structure)
    """

def unflatten(structure: TreeDef, flat_sequence: list) -> Tree:
    """
    Reconstruct a tree from flattened data and structure.

    Args:
        structure: Tree structure definition from flatten()
        flat_sequence: Flattened list of leaf values

    Returns:
        Reconstructed tree structure
    """

def assert_same_structure(tree0: Tree, tree1: Tree) -> None:
    """
    Assert that two trees have the same structure.

    Args:
        tree0: First tree
        tree1: Second tree

    Raises:
        ValueError: If structures don't match
    """
```
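
The backend-level `flatten`, `unflatten`, and `assert_same_structure` functions are not covered in the usage examples below, so here is a minimal sketch based on the signatures above. It assumes these functions are reached through the `backend` attribute of a tree API object, as the comment in the block indicates; the concrete `TreeDef` value is backend-specific.

```python
from etils import etree

params = {'w': [1.0, 2.0], 'b': 0.5}

# Flatten into a flat list of leaves plus a structure definition
leaves, treedef = etree.py.backend.flatten(params)

# Transform the flat leaves, then rebuild the original structure
rebuilt = etree.py.backend.unflatten(treedef, [leaf * 2 for leaf in leaves])

# Both trees share the same structure, so this should not raise
etree.py.backend.assert_same_structure(params, rebuilt)
```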

### Backend Modules

Access to underlying backend implementations.

```python { .api }
backend: ModuleType  # Backend implementations module
tree_utils: ModuleType  # Core tree utility functions module
```
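
Both attributes are plain modules; a minimal sketch of accessing them, assuming they are exposed directly on `etils.etree` as declared above:

```python
from etils import etree

# Inspect the modules to see which backends and helpers are available
print(etree.backend)
print(etree.tree_utils)
```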

## Usage Examples

### Basic Tree Operations

```python
from etils import etree

# Define a nested data structure
data = {
    'params': {
        'weights': [[1.0, 2.0], [3.0, 4.0]],
        'bias': [0.1, 0.2]
    },
    'config': {
        'learning_rate': 0.01,
        'batch_size': 32
    }
}

# Apply function to all numeric values
doubled = etree.py.map(lambda x: x * 2 if isinstance(x, (int, float)) else x, data)
# Result: All numeric values doubled

# Deep copy the structure
data_copy = etree.py.copy(data)
```
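
The `is_leaf` argument from the API above does not appear in these examples; the following sketch assumes the usual semantics, where a node for which `is_leaf` returns `True` is passed to `map_fn` whole instead of being recursed into:

```python
from etils import etree

data = {'bias': [1, 2], 'scale': [3, 4, 5]}

# Without is_leaf, sum() would be called on each int separately (and fail).
# With is_leaf, every list is treated as a single leaf and summed whole.
totals = etree.py.map(sum, data, is_leaf=lambda node: isinstance(node, list))
# Expected result under the assumed semantics: {'bias': 3, 'scale': 12}
```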

### Working with Multiple Trees

```python
from etils import etree

# Multiple parameter sets
tree1 = {'a': [1, 2], 'b': {'c': 3}}
tree2 = {'a': [4, 5], 'b': {'c': 6}}

# Combine operations across trees
combined = etree.py.map(lambda x, y: x + y, tree1, tree2)
# Result: {'a': [5, 7], 'b': {'c': 9}}
```

### Framework Compatibility

```python
from etils import etree
import jax
import tensorflow as tf

# JAX compatibility
jax_tree = {'params': jax.numpy.array([1, 2, 3])}
processed_jax = etree.jax.map(lambda x: x * 2, jax_tree)

# TensorFlow compatibility
tf_tree = {'weights': tf.constant([1.0, 2.0, 3.0])}
processed_tf = etree.nest.map(lambda x: x * 2, tf_tree)

# Pure Python (default)
py_tree = {'data': [1, 2, 3]}
processed_py = etree.py.map(lambda x: x * 2, py_tree)
```
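
The `optree` backend declared earlier follows the same pattern; a minimal sketch, assuming the optional `optree` package is installed:

```python
from etils import etree

# Optree compatibility
opt_tree = {'params': [1, 2, 3]}
processed_opt = etree.optree.map(lambda x: x * 2, opt_tree)
```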

### Advanced Tree Operations

```python
from etils import etree

# Placeholder feature/label values for illustration
x1, y1, x2, y2, x3, y3 = 1.0, 0, 2.0, 1, 3.0, 0
x4, y4, x5, y5 = 4.0, 1, 5.0, 0

# Unzip paired data
paired_data = {
    'train': [(x1, y1), (x2, y2), (x3, y3)],
    'test': [(x4, y4), (x5, y5)]
}
x_data, y_data = etree.py.unzip(paired_data)

# Stack multiple examples
examples = [
    {'features': [1, 2], 'label': 0},
    {'features': [3, 4], 'label': 1},
    {'features': [5, 6], 'label': 0}
]
batched = etree.py.stack(examples)
# Result: {'features': [[1, 2], [3, 4], [5, 6]], 'label': [0, 1, 0]}

# Create spec structure
spec = etree.py.spec_like(batched)
# Result: structure matching batched, but with spec information
```

### Parallel Processing

```python
from etils import etree
import numpy as np

# Large data structure with expensive operations
large_data = {
    'layer1': {'weights': np.random.rand(1000, 1000)},
    'layer2': {'weights': np.random.rand(1000, 1000)},
    'layer3': {'weights': np.random.rand(1000, 1000)}
}

# Expensive function (e.g., matrix operations)
def expensive_op(x):
    return np.linalg.svd(x)[0]  # SVD decomposition

# Apply in parallel for better performance
result = etree.py.parallel_map(expensive_op, large_data)
```
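
`parallel_map` also accepts the `num_threads` and `progress_bar` options documented above; the values below are illustrative, and the progress bar may require an optional dependency such as `tqdm`:

```python
from etils import etree
import numpy as np

large_data = {
    'layer1': {'weights': np.random.rand(256, 256)},
    'layer2': {'weights': np.random.rand(256, 256)}
}

# Cap the worker count and display a progress bar while mapping
result = etree.py.parallel_map(
    lambda x: np.linalg.svd(x)[0],
    large_data,
    num_threads=4,
    progress_bar=True,
)
```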