or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced-optimizers.md assignment.md contrib.md index.md losses.md monte-carlo.md optimizers.md perturbations.md projections.md schedules.md second-order.md transformations.md tree-utilities.md utilities.md

docs/tree-utilities.md

0

# Tree Utilities

1

2

JAX PyTree manipulation utilities for working with nested parameter structures. These functions provide efficient operations on tree-structured data, which is common in neural network parameters and gradients.

3

4

## Capabilities

5

6

### Tree Manipulation

7

8

```python { .api }

9

def tree_cast(tree, dtype):

10

"""Cast all leaves of a tree to specified dtype."""

11

12

def tree_cast_like(tree, target_tree):

13

"""Cast tree leaves to match dtypes of target tree."""

14

15

def tree_dtype(tree):

16

"""Get dtype information for tree leaves."""

17

18

def tree_get(tree, path):

19

"""Get value at specified path in tree."""

20

21

def tree_set(tree, path, value):

22

"""Set value at specified path in tree."""

23

24

def tree_get_all_with_path(tree):

25

"""Get all (path, value) pairs from tree."""

26

```

27

28

### Tree Arithmetic

29

30

```python { .api }

31

def tree_add(tree_a, tree_b):

32

"""Element-wise addition of two trees."""

33

34

def tree_sub(tree_a, tree_b):

35

"""Element-wise subtraction of two trees."""

36

37

def tree_mul(tree_a, tree_b):

38

"""Element-wise multiplication of two trees."""

39

40

def tree_div(tree_a, tree_b):

41

"""Element-wise division of two trees."""

42

43

def tree_add_scale(tree, scalar, scaled_tree):

44

"""Compute tree + scalar * scaled_tree."""

45

46

def tree_scale(tree, scalar):

47

"""Scale all leaves of tree by scalar."""

48

```

49

50

### Tree Reductions

51

52

```python { .api }

53

def tree_sum(tree):

54

"""Sum all leaves in tree."""

55

56

def tree_max(tree):

57

"""Maximum value across all leaves."""

58

59

def tree_vdot(tree_a, tree_b):

60

"""Vector dot product of flattened trees."""

61

62

def tree_batch_shape(tree):

63

"""Get batch shape from tree structure."""

64

```

65

66

### Tree Creation

67

68

```python { .api }

69

def tree_zeros_like(tree):

70

"""Create tree of zeros with same structure."""

71

72

def tree_ones_like(tree):

73

"""Create tree of ones with same structure."""

74

75

def tree_full_like(tree, fill_value):

76

"""Create tree filled with specified value."""

77

78

def tree_random_like(tree, key):

79

"""Create tree of random values with same structure."""

80

```

81

82

### Complex Number Support

83

84

```python { .api }

85

def tree_real(tree):

86

"""Extract real parts of complex tree."""

87

88

def tree_conj(tree):

89

"""Complex conjugate of tree."""

90

91

def tree_where(condition, tree_a, tree_b):

92

"""Element-wise selection between trees."""

93

```

94

95

## Usage Examples

96

97

```python

98

import optax

99

import jax.numpy as jnp

100

import jax

101

102

# Example tree structure (neural network parameters)

103

params = {

104

'dense1': {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros(5)},

105

'dense2': {'weights': jnp.ones((5, 1)), 'bias': jnp.zeros(1)}

106

}

107

108

# Tree arithmetic operations

109

grads = optax.tree.tree_zeros_like(params)

110

scaled_grads = optax.tree.tree_scale(grads, 0.01)

111

updated_params = optax.tree.tree_sub(params, scaled_grads)

112

113

# Tree reductions

114

total_params = optax.tree.tree_sum(

115

jax.tree_util.tree_map(lambda x: x.size, params)

116

)

117

param_norm = jnp.sqrt(optax.tree.tree_vdot(params, params))

118

119

# Random initialization

120

key = jax.random.PRNGKey(42)

121

random_params = optax.tree.tree_random_like(params, key)

122

123

# Working with paths

124

weight_path = ('dense1', 'weights')

125

weights = optax.tree.tree_get(params, weight_path)

126

new_params = optax.tree.tree_set(params, weight_path, weights * 0.9)

127

```

128

129

## Import

130

131

```python

132

import optax.tree

133

# or

134

from optax.tree import tree_add, tree_scale, tree_zeros_like

135

```