Differentiate, compile, and transform Numpy code.
—
JAX provides utilities for working with PyTrees (nested Python data structures containing arrays) through jax.tree. PyTrees are fundamental to JAX's functional programming approach and enable elegant handling of complex nested data structures like neural network parameters.
from typing import Any

import jax
import jax.numpy as jnp
import jax.tree as jtree
from jax.tree import map, flatten, unflatten, reduce

PyTrees are nested Python data structures whose internal nodes are containers (lists, tuples, dicts, namedtuples) and whose leaves are arrays or scalars.
Common PyTree examples:
# Simple trees
tree1 = [1, 2, 3]  # list of scalar leaves
tree2 = {'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])}  # dict of array leaves

# Nested trees (neural network parameters)
params = {
    'dense1': {'weight': jnp.zeros((784, 128)), 'bias': jnp.zeros(128)},
    'dense2': {'weight': jnp.zeros((128, 10)), 'bias': jnp.zeros(10)},
}

# Mixed structures
state = {
    'params': params,
    'batch_stats': {'mean': jnp.zeros(128), 'var': jnp.ones(128)},
    'step': 0,  # scalar leaf
}
# Apply functions to all leaves while preserving tree structure.
def map(f, tree, *rest, is_leaf=None) -> Any:
    """Apply a function to every leaf of one or more PyTrees.

    Args:
        f: Function applied to corresponding leaves; takes one positional
            argument per input tree.
        tree: Primary PyTree.
        *rest: Additional PyTrees, each with the same structure as ``tree``.
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        A PyTree with the same structure as ``tree`` whose leaves are the
        results of ``f``.
    """
    return jax.tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
def map_with_path(f, tree, *rest, is_leaf=None) -> Any:
    """Apply ``f`` to every leaf together with the leaf's path in the tree.

    Args:
        f: Function called as ``f(path, *leaves)`` where ``path`` is a tuple
            of key entries locating the leaf.
        tree: Primary PyTree.
        *rest: Additional PyTrees with the same structure as ``tree``.
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        PyTree of ``f``'s results, with the same structure as ``tree``.
    """
    return jax.tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf)
def reduce(function, tree, initializer=None, is_leaf=None) -> Any:
    """Reduce a PyTree to a single value by folding ``function`` over leaves.

    Args:
        function: Binary function combining the accumulator with each leaf.
        tree: PyTree to reduce.
        initializer: Optional initial accumulator value; when ``None`` the
            reduction starts from the first leaf (``None`` itself therefore
            cannot serve as an initial value).
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        The single accumulated value.
    """
    # ``None`` is this wrapper's "no initializer" sentinel; jax.tree_util
    # uses its own private sentinel, so dispatch explicitly.
    if initializer is None:
        return jax.tree_util.tree_reduce(function, tree, is_leaf=is_leaf)
    return jax.tree_util.tree_reduce(function, tree, initializer, is_leaf=is_leaf)
def all(tree) -> bool:
    """Return True if every leaf of ``tree`` is truthy.

    Args:
        tree: PyTree to check.

    Returns:
        True when all leaves are truthy (vacuously True for an empty tree).
    """
    return jax.tree_util.tree_all(tree)
# Usage examples:
# Apply function to all arrays in parameter tree
def init_weights(params):
    """Scale every leaf of ``params`` by 0.01 (toy weight initialization)."""
    return jtree.map(lambda x: x * 0.01, params)
# Element-wise operations on multiple trees
def add_trees(tree1, tree2):
    """Leaf-wise sum of two PyTrees with identical structure."""
    return jtree.map(lambda x, y: x + y, tree1, tree2)
# Compute total number of parameters
def count_params(params):
    """Total element count across all array leaves of ``params``."""
    return jtree.reduce(lambda count, x: count + x.size, params, initializer=0)
# Check if all gradients are finite
def all_finite(grads):
    """Return True iff every element of every leaf in ``grads`` is finite.

    Each leaf is reduced to a scalar with ``jnp.all`` first: ``jtree.all``
    calls ``bool()`` on each leaf, which raises for arrays that have more
    than one element.
    """
    return jtree.all(jtree.map(lambda g: jnp.all(jnp.isfinite(g)), grads))
# Apply different functions based on path
def scale_by_path(path, param):
    """Scale ``param`` by 0.1 when its path mentions 'bias', else leave it.

    Path entries produced by ``jtree.map_with_path`` are key objects
    (e.g. ``DictKey``), not plain strings, so compare against ``str(key)``.
    """
    if any('bias' in str(key) for key in path):
        return param * 0.1  # smaller learning rate for biases
    return param * 1.0
scaled_grads = jtree.map_with_path(scale_by_path, gradients)

Flatten trees into lists and reconstruct them, useful for interfacing with optimizers and other libraries.
def flatten(tree, is_leaf=None) -> tuple[list, Any]:
    """Flatten a PyTree into its leaves plus a structure definition.

    Args:
        tree: PyTree to flatten.
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        ``(leaves, treedef)`` where ``leaves`` is a flat list (dict entries
        appear in sorted-key order) and ``treedef`` can rebuild the tree
        via ``unflatten``.
    """
    return jax.tree_util.tree_flatten(tree, is_leaf=is_leaf)
def unflatten(treedef, leaves) -> Any:
    """Rebuild a PyTree from a tree definition and a list of leaves.

    Args:
        treedef: Structure definition produced by ``flatten``.
        leaves: Leaf values, in the same order ``flatten`` returned them.

    Returns:
        PyTree with the original structure and the given leaves.
    """
    return jax.tree_util.tree_unflatten(treedef, leaves)
def flatten_with_path(tree, is_leaf=None) -> tuple[list, Any]:
    """Flatten a PyTree, keeping each leaf's path.

    Args:
        tree: PyTree to flatten.
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        ``(pairs, treedef)`` where ``pairs`` is a list of ``(path, leaf)``
        tuples and ``treedef`` is the structure definition (the second
        element is a treedef, not a list, hence the ``Any`` annotation).
    """
    return jax.tree_util.tree_flatten_with_path(tree, is_leaf=is_leaf)
def leaves(tree, is_leaf=None) -> list:
    """Return all leaf values of ``tree`` as a flat list.

    Args:
        tree: PyTree to extract leaves from.
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        List of leaves (dict entries are visited in sorted-key order).
    """
    return jax.tree_util.tree_leaves(tree, is_leaf=is_leaf)
def leaves_with_path(tree, is_leaf=None) -> list:
    """Return ``(path, leaf)`` pairs for every leaf of ``tree``.

    Args:
        tree: PyTree to extract leaves from.
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        List of ``(path, leaf)`` tuples, one per leaf.
    """
    return jax.tree_util.tree_leaves_with_path(tree, is_leaf=is_leaf)
def structure(tree, is_leaf=None) -> Any:
    """Return the tree definition of ``tree`` without its leaf values.

    Args:
        tree: PyTree to get the structure of.
        is_leaf: Optional predicate deciding which subtrees count as leaves.

    Returns:
        A treedef describing the structure; usable with ``unflatten``.
    """
    return jax.tree_util.tree_structure(tree, is_leaf=is_leaf)
# Usage examples:
# Flatten for use with scipy optimizers
params = {'w': jnp.array([1, 2]), 'b': jnp.array([3])}
flat_params, tree_def = jtree.flatten(params)
print(flat_params) # [Array([1, 2]), Array([3])]
# Reconstruct after optimization
new_flat_params = [jnp.array([4, 5]), jnp.array([6])]
new_params = jtree.unflatten(tree_def, new_flat_params)
print(new_params) # {'w': Array([4, 5]), 'b': Array([6])}
# Get all parameter arrays
all_arrays = jtree.leaves(params)
# Inspect structure with paths
path_leaf_pairs = jtree.leaves_with_path(params)
print(path_leaf_pairs) # [(('w',), Array([1, 2])), (('b',), Array([3]))]
# Get structure for later use
structure_only = jtree.structure(params)Advanced operations for tree manipulation and structural transformations.
def transpose(outer_treedef, inner_treedef, pytree_to_transpose) -> Any:
    """Swap the nesting order of a PyTree-of-PyTrees.

    Args:
        outer_treedef: Structure of the *input* tree's outer level.
        inner_treedef: Structure of each of the input's inner subtrees.
        pytree_to_transpose: Tree whose outer structure matches
            ``outer_treedef`` and whose subtrees all match ``inner_treedef``.

    Returns:
        A tree with the inner structure outermost (inner-of-outer becomes
        outer-of-inner).
    """
    return jax.tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
# Usage example:
# Transpose structure: list of dicts -> dict of lists
list_of_dicts = [
    {'a': 1, 'b': 2},
    {'a': 3, 'b': 4},
    {'a': 5, 'b': 6},
]

# Get structure definitions. Use dummy non-None leaves: None is treated as
# an *empty subtree* by JAX treedefs, so {'a': None, 'b': None} would have
# no leaves at all. Also note jtree.structure(list_of_dicts) would describe
# the whole nested tree, not just the outer list.
outer_structure = jtree.structure([0, 0, 0])         # outer: 3-element list
inner_structure = jtree.structure({'a': 0, 'b': 0})  # inner: 2-key dict

# Transpose to dict of lists. Argument order is (outer, inner, tree): the
# input's own outer structure comes first.
dict_of_lists = jtree.transpose(outer_structure, inner_structure, list_of_dicts)
print(dict_of_lists)  # {'a': [1, 3, 5], 'b': [2, 4, 6]}


def broadcast(f, tree, *rest) -> Any:
    """
    Broadcast function application across PyTree structures.

    Args:
        f: Function to broadcast
        tree: Primary PyTree
        rest: Additional PyTrees (may have different but compatible structures)

    Returns:
        PyTree result of broadcasting f across inputs
    """
    # NOTE(review): verify this signature against the installed JAX version —
    # recent releases define jax.tree.broadcast(prefix_tree, full_tree).
# Register custom classes as PyTree nodes:
import jax

# Register custom class as PyTree node
class MyContainer:
    """Minimal container wrapping a dict of arrays."""

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return f"MyContainer({self.data})"


def container_flatten(container):
    # Return (children, aux_data): children are PyTrees; aux_data must be
    # hashable and is handed back to the unflatten function unchanged.
    return (container.data.values(), tuple(container.data.keys()))


def container_unflatten(aux_data, children):
    # Reconstruct from aux_data (the keys) and children (the values).
    return MyContainer(dict(zip(aux_data, children)))


# Register the PyTree node
jax.tree_util.register_pytree_node(
    MyContainer,
    container_flatten,
    container_unflatten,
)

# Now MyContainer works with tree operations
container = MyContainer({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
doubled = jtree.map(lambda x: x * 2, container)
print(doubled)  # MyContainer({'x': Array([2, 4]), 'y': Array([6, 8])})
# Initialize network parameters as PyTree
def init_mlp_params(layer_sizes, key):
    """Initialize MLP parameters as a nested PyTree.

    Args:
        layer_sizes: Sequence of layer widths, e.g. ``[784, 128, 10]``.
        key: ``jax.random`` PRNG key.

    Returns:
        Dict mapping ``'layer_{i}'`` to ``{'weights': (in, out) array,
        'biases': (out,) zeros}``.
    """
    params = {}
    keys = jax.random.split(key, len(layer_sizes) - 1)
    for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        # b_key is unused (biases start at zero); split kept for parity with
        # common init recipes that randomize biases too.
        w_key, b_key = jax.random.split(keys[i])
        params[f'layer_{i}'] = {
            'weights': jax.random.normal(w_key, (in_size, out_size)) * 0.01,
            'biases': jnp.zeros(out_size),
        }
    return params
# Apply gradients using tree operations
def update_params(params, grads, learning_rate):
    """One SGD step, applied leaf-wise: ``p <- p - learning_rate * g``."""
    return jtree.map(lambda p, g: p - learning_rate * g, params, grads)
# Compute parameter statistics
def param_stats(params):
    """Return the total parameter count and global L2 norm of ``params``."""
    flat_params = jtree.leaves(params)
    total_params = sum(p.size for p in flat_params)
    param_norm = jnp.sqrt(sum(jnp.sum(p**2) for p in flat_params))
    return {'total_params': total_params, 'norm': param_norm}
# Adam optimizer state as PyTree
def init_adam_state(params):
    """Zero-initialized Adam state matching the structure of ``params``."""
    return {
        'm': jtree.map(jnp.zeros_like, params),  # first moment
        'v': jtree.map(jnp.zeros_like, params),  # second moment
        'step': 0,
    }
def adam_update(params, grads, state, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step over PyTrees of parameters and gradients.

    Args:
        params: PyTree of parameters.
        grads: PyTree of gradients, same structure as ``params``.
        state: Dict with moment trees ``'m'``, ``'v'`` and integer ``'step'``.
        learning_rate, beta1, beta2, eps: Standard Adam hyperparameters.

    Returns:
        ``(new_params, new_state)`` tuple.
    """
    step = state['step'] + 1
    # Update biased first/second moment estimates
    m = jtree.map(lambda m_prev, g: beta1 * m_prev + (1 - beta1) * g, state['m'], grads)
    v = jtree.map(lambda v_prev, g: beta2 * v_prev + (1 - beta2) * g**2, state['v'], grads)
    # Bias correction (moments start at zero, so early steps are rescaled)
    m_hat = jtree.map(lambda m_val: m_val / (1 - beta1**step), m)
    v_hat = jtree.map(lambda v_val: v_val / (1 - beta2**step), v)
    # Parameter update
    new_params = jtree.map(
        lambda p, m_val, v_val: p - learning_rate * m_val / (jnp.sqrt(v_val) + eps),
        params, m_hat, v_hat,
    )
    new_state = {'m': m, 'v': v, 'step': step}
    return new_params, new_state
# Process batch of PyTrees
def process_batch(batch_trees):
    """Stack a list of same-structure PyTrees into one PyTree of batched arrays."""
    # Zip corresponding leaves across the trees and stack along a new axis 0.
    return jtree.map(lambda *arrays: jnp.stack(arrays), *batch_trees)


# Example: batch of neural network inputs
batch_inputs = [
    {'image': jnp.ones((28, 28)), 'label': 5},
    {'image': jnp.zeros((28, 28)), 'label': 3},
    {'image': jnp.ones((28, 28)) * 0.5, 'label': 1},
]
batched = process_batch(batch_inputs)
print(batched['image'].shape)  # (3, 28, 28)
print(batched['label'].shape)  # (3,)

Install with Tessl CLI
npx tessl i tessl/pypi-jax