or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-transformations.md device-memory.md experimental.md index.md low-level-ops.md neural-networks.md numpy-compatibility.md random-numbers.md scipy-compatibility.md tree-operations.md

docs/tree-operations.md

0

# Tree Operations

1

2

JAX provides utilities for working with PyTrees (nested Python data structures containing arrays) through `jax.tree`. PyTrees are fundamental to JAX's functional programming approach and enable elegant handling of complex nested data structures like neural network parameters.

3

4

## Core Imports

5

6

```python

7

import jax.numpy as jnp
import jax.tree as jtree

8

from jax.tree import map, flatten, unflatten, reduce

9

```

10

11

## What are PyTrees?

12

13

PyTrees are nested Python data structures where:

14

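- **Leaves** are arrays, scalars, and any other objects not registered as containers (note that `None` is treated as an empty node with no leaves, not as a leaf)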

15

- **Nodes** are containers like lists, tuples, dicts, or custom classes

16

- The tree structure is preserved while operations apply to leaves

17

18

Common PyTree examples:

19

```python

20

# Simple trees

21

tree1 = [1, 2, 3] # List of scalars

22

tree2 = {'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])} # Dict of arrays

23

24

# Nested trees (neural network parameters)

25

params = {

26

'dense1': {'weight': jnp.zeros((784, 128)), 'bias': jnp.zeros(128)},

27

'dense2': {'weight': jnp.zeros((128, 10)), 'bias': jnp.zeros(10)}

28

}

29

30

# Mixed structures

31

state = {

32

'params': params,

33

'batch_stats': {'mean': jnp.zeros(128), 'var': jnp.ones(128)},

34

'step': 0 # Scalar leaf

35

}

36

```

37

38

## Capabilities

39

40

### Tree Traversal and Transformation

41

42

Apply functions to all leaves while preserving tree structure.

43

44

```python { .api }

45

def map(f, tree, *rest, is_leaf=None) -> Any:

46

"""

47

Apply function to all leaves of one or more trees.

48

49

Args:

50

f: Function to apply to leaves

51

tree: Primary PyTree

52

rest: Additional PyTrees with same structure

53

is_leaf: Optional function to determine what counts as leaf

54

55

Returns:

56

PyTree with same structure as input, f applied to all leaves

57

"""

58

59

def map_with_path(f, tree, *rest, is_leaf=None) -> Any:

60

"""

61

Apply function to leaves with path information.

62

63

Args:

64

f: Function taking (path, *leaves) as arguments

65

tree: Primary PyTree

66

rest: Additional PyTrees with same structure

67

is_leaf: Optional function to determine what counts as leaf

68

69

Returns:

70

PyTree with f applied to leaves, receiving path info

71

"""

72

73

def reduce(function, tree, initializer=None, is_leaf=None) -> Any:

74

"""

75

Reduce tree to single value by applying function to all leaves.

76

77

Args:

78

function: Binary function to combine leaves

79

tree: PyTree to reduce

80

initializer: Optional initial value for reduction

81

is_leaf: Optional function to determine what counts as leaf

82

83

Returns:

84

Single value from reducing all leaves

85

"""

86

87

def all(tree) -> bool:

88

"""

89

Return True if all leaves are truthy.

90

91

Args:

92

tree: PyTree to check

93

94

Returns:

95

Boolean indicating if all leaves are truthy

96

"""

97

```

98

99

Usage examples:

100

```python

101

# Apply function to all arrays in parameter tree

102

def init_weights(params):

103

return jtree.map(lambda x: x * 0.01, params)

104

105

# Element-wise operations on multiple trees

106

def add_trees(tree1, tree2):

107

return jtree.map(lambda x, y: x + y, tree1, tree2)

108

109

# Compute total number of parameters

110

def count_params(params):

111

return jtree.reduce(lambda count, x: count + x.size, params, initializer=0)

112

113

# Check if all gradients are finite

114

def all_finite(grads):

115

return jtree.all(jtree.map(lambda g: jnp.all(jnp.isfinite(g)), grads)) # Reduce each leaf to a single boolean first

116

117

# Apply different functions based on path

118

def scale_by_path(path, param):

119

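if any(getattr(key, 'key', None) == 'bias' for key in path): # path entries are key objects (e.g. DictKey), not strings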

120

return param * 0.1 # Smaller learning rate for biases

121

else:

122

return param * 1.0

123

124

scaled_grads = jtree.map_with_path(scale_by_path, gradients)

125

```

126

127

### Tree Structure Operations

128

129

Flatten trees into lists and reconstruct them, useful for interfacing with optimizers and other libraries.

130

131

```python { .api }

132

def flatten(tree, is_leaf=None) -> tuple[list, Any]:

133

"""

134

Flatten PyTree into list of leaves and tree definition.

135

136

Args:

137

tree: PyTree to flatten

138

is_leaf: Optional function to determine what counts as leaf

139

140

Returns:

141

Tuple of (leaves_list, tree_definition)

142

"""

143

144

def unflatten(treedef, leaves) -> Any:

145

"""

146

Reconstruct PyTree from tree definition and leaves.

147

148

Args:

149

treedef: Tree definition from flatten()

150

leaves: List of leaf values

151

152

Returns:

153

Reconstructed PyTree with original structure

154

"""

155

156

def flatten_with_path(tree, is_leaf=None) -> tuple[list, Any]:

157

"""

158

Flatten PyTree with path information for each leaf.

159

160

Args:

161

tree: PyTree to flatten

162

is_leaf: Optional function to determine what counts as leaf

163

164

Returns:

165

Tuple of (path_leaf_pairs, tree_definition)

166

"""

167

168

def leaves(tree, is_leaf=None) -> list:

169

"""

170

Get list of all leaves in PyTree.

171

172

Args:

173

tree: PyTree to extract leaves from

174

is_leaf: Optional function to determine what counts as leaf

175

176

Returns:

177

List containing all leaf values

178

"""

179

180

def leaves_with_path(tree, is_leaf=None) -> list:

181

"""

182

Get list of (path, leaf) pairs.

183

184

Args:

185

tree: PyTree to extract leaves from

186

is_leaf: Optional function to determine what counts as leaf

187

188

Returns:

189

List of (path, leaf) tuples

190

"""

191

192

def structure(tree, is_leaf=None) -> Any:

193

"""

194

Get tree structure (definition) without leaf values.

195

196

Args:

197

tree: PyTree to get structure from

198

is_leaf: Optional function to determine what counts as leaf

199

200

Returns:

201

Tree definition describing structure

202

"""

203

```

204

205

Usage examples:

206

```python

207

# Flatten for use with scipy optimizers

208

params = {'w': jnp.array([1, 2]), 'b': jnp.array([3])}

209

flat_params, tree_def = jtree.flatten(params)

210

print(flat_params) # [Array([3]), Array([1, 2])] (dict leaves are ordered by sorted key: 'b', then 'w')

211

212

# Reconstruct after optimization

213

new_flat_params = [jnp.array([6]), jnp.array([4, 5])] # Leaf order follows sorted dict keys: 'b', then 'w'

214

new_params = jtree.unflatten(tree_def, new_flat_params)

215

print(new_params) # {'w': Array([4, 5]), 'b': Array([6])}

216

217

# Get all parameter arrays

218

all_arrays = jtree.leaves(params)

219

220

# Inspect structure with paths

221

path_leaf_pairs = jtree.leaves_with_path(params)

222

print(path_leaf_pairs) # [((DictKey(key='b'),), Array([3])), ((DictKey(key='w'),), Array([1, 2]))]

223

224

# Get structure for later use

225

structure_only = jtree.structure(params)

226

```

227

228

### Tree Transformation and Manipulation

229

230

Advanced operations for tree manipulation and structural transformations.

231

232

```python { .api }

233

def transpose(outer_treedef, inner_treedef, pytree_to_transpose) -> Any:

234

"""

235

Transpose nested PyTree structure.

236

237

Args:

238

outer_treedef: Target outer tree structure

239

inner_treedef: Target inner tree structure

240

pytree_to_transpose: PyTree to transpose

241

242

Returns:

243

PyTree with transposed nested structure

244

"""

245

```

246

247

Usage example:

248

```python

249

# Transpose structure: list of dicts -> dict of lists

250

list_of_dicts = [

251

{'a': 1, 'b': 2},

252

{'a': 3, 'b': 4},

253

{'a': 5, 'b': 6}

254

]

255

256

# Get structure definitions

257

outer_structure = jtree.structure([0] * len(list_of_dicts)) # List structure (one placeholder leaf per element)

258

inner_structure = jtree.structure({'a': 0, 'b': 0}) # Dict structure (None would not count as a leaf)

259

260

# Transpose to dict of lists

261

dict_of_lists = jtree.transpose(outer_structure, inner_structure, list_of_dicts)

262

print(dict_of_lists) # {'a': [1, 3, 5], 'b': [2, 4, 6]}

263

```

264

265

### Broadcasting and Advanced Operations

266

267

```python { .api }

268

def broadcast(prefix_tree, full_tree, is_leaf=None) -> Any:

269

"""

270

Broadcast a tree prefix into the full structure of a given tree.

271

272

Args:

273

prefix_tree: PyTree whose structure is a prefix of full_tree

274

full_tree: PyTree whose structure the prefix is broadcast to

275

is_leaf: Optional function to determine what counts as leaf

276

277

Returns:

278

PyTree with the structure of full_tree, with each prefix leaf repeated across the corresponding subtree

279

"""

280

```

281

282

## Custom PyTree Types

283

284

Register custom classes as PyTree nodes:

285

286

```python

287

import jax

288

289

# Register custom class as PyTree node

290

class MyContainer:

291

def __init__(self, data):

292

self.data = data

293

294

def __repr__(self):

295

return f"MyContainer({self.data})"

296

297

def container_flatten(container):

298

# Return (children, aux_data) where children are PyTrees

299

return (container.data.values(), tuple(container.data.keys()))

300

301

def container_unflatten(aux_data, children):

302

# Reconstruct from aux_data and children

303

return MyContainer(dict(zip(aux_data, children)))

304

305

# Register the PyTree node

306

jax.tree_util.register_pytree_node(

307

MyContainer,

308

container_flatten,

309

container_unflatten

310

)

311

312

# Now MyContainer works with tree operations

313

container = MyContainer({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})

314

doubled = jtree.map(lambda x: x * 2, container)

315

print(doubled) # MyContainer({'x': Array([2, 4]), 'y': Array([6, 8])})

316

```

317

318

## Common Usage Patterns

319

320

### Neural Network Parameter Management

321

322

```python

323

# Initialize network parameters as PyTree

324

def init_mlp_params(layer_sizes, key):

325

params = {}

326

keys = jax.random.split(key, len(layer_sizes) - 1)

327

328

for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):

329

w_key, b_key = jax.random.split(keys[i])

330

params[f'layer_{i}'] = {

331

'weights': jax.random.normal(w_key, (in_size, out_size)) * 0.01,

332

'biases': jnp.zeros(out_size)

333

}

334

return params

335

336

# Apply gradients using tree operations

337

def update_params(params, grads, learning_rate):

338

return jtree.map(lambda p, g: p - learning_rate * g, params, grads)

339

340

# Compute parameter statistics

341

def param_stats(params):

342

flat_params = jtree.leaves(params)

343

total_params = sum(p.size for p in flat_params)

344

param_norm = jnp.sqrt(sum(jnp.sum(p**2) for p in flat_params))

345

return {'total_params': total_params, 'norm': param_norm}

346

```

347

348

### Optimizer State Management

349

350

```python

351

# Adam optimizer state as PyTree

352

def init_adam_state(params):

353

return {

354

'm': jtree.map(jnp.zeros_like, params), # First moment

355

'v': jtree.map(jnp.zeros_like, params), # Second moment

356

'step': 0

357

}

358

359

def adam_update(params, grads, state, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):

360

step = state['step'] + 1

361

362

# Update biased moments

363

m = jtree.map(lambda m_prev, g: beta1 * m_prev + (1 - beta1) * g, state['m'], grads)

364

v = jtree.map(lambda v_prev, g: beta2 * v_prev + (1 - beta2) * g**2, state['v'], grads)

365

366

# Bias correction

367

m_hat = jtree.map(lambda m_val: m_val / (1 - beta1**step), m)

368

v_hat = jtree.map(lambda v_val: v_val / (1 - beta2**step), v)

369

370

# Parameter update

371

new_params = jtree.map(

372

lambda p, m_val, v_val: p - learning_rate * m_val / (jnp.sqrt(v_val) + eps),

373

params, m_hat, v_hat

374

)

375

376

new_state = {'m': m, 'v': v, 'step': step}

377

return new_params, new_state

378

```

379

380

### Batch Processing

381

382

```python

383

# Process batch of PyTrees

384

def process_batch(batch_trees):

385

# batch_trees is a list of PyTrees

386

# Convert to PyTree of batched arrays

387

return jtree.map(lambda *arrays: jnp.stack(arrays), *batch_trees)

388

389

# Example: batch of neural network inputs

390

batch_inputs = [

391

{'image': jnp.ones((28, 28)), 'label': 5},

392

{'image': jnp.zeros((28, 28)), 'label': 3},

393

{'image': jnp.ones((28, 28)) * 0.5, 'label': 1}

394

]

395

396

batched = process_batch(batch_inputs)

397

print(batched['image'].shape) # (3, 28, 28)

398

print(batched['label'].shape) # (3,)

399

```