CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-jax

Differentiate, compile, and transform Numpy code.

Pending
Overview
Eval results
Files

tree-operations.mddocs/

Tree Operations

JAX provides utilities for working with PyTrees (nested Python data structures containing arrays) through jax.tree. PyTrees are fundamental to JAX's functional programming approach and enable elegant handling of complex nested data structures like neural network parameters.

Core Imports

import jax.tree as jtree
from jax.tree import map, flatten, unflatten, reduce

What are PyTrees?

PyTrees are nested Python data structures where:

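  • Leaves are arrays, scalars, or any other object that is not a registered container (note: None is treated as an empty node with no leaves, not as a leaf)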
  • Nodes are containers like lists, tuples, dicts, or custom classes
  • The tree structure is preserved while operations apply to leaves

Common PyTree examples:

# Flat containers
tree1 = [1, 2, 3]  # a list whose leaves are Python scalars
tree2 = dict(a=jnp.array([1, 2]), b=jnp.array([3, 4]))  # a dict whose leaves are arrays

# Nested containers, e.g. the parameters of a two-layer network
params = dict(
    dense1=dict(weight=jnp.zeros((784, 128)), bias=jnp.zeros(128)),
    dense2=dict(weight=jnp.zeros((128, 10)), bias=jnp.zeros(10)),
)

# Containers may mix sub-trees, arrays and plain scalars
state = dict(
    params=params,
    batch_stats=dict(mean=jnp.zeros(128), var=jnp.ones(128)),
    step=0,  # scalar leaf
)

Capabilities

Tree Traversal and Transformation

Apply functions to all leaves while preserving tree structure.

def map(f, tree, *rest, is_leaf=None) -> Any:
    """
    Apply a function leaf-wise to one or more PyTrees.

    Leaves found at the same position in ``tree`` and in every tree of
    ``rest`` are passed together: ``f(leaf, *rest_leaves)``.

    Args:
        f: Function called once per leaf position; takes ``1 + len(rest)`` arguments
        tree: Primary PyTree; its structure determines the structure of the result
        rest: Additional PyTrees, each with the same structure as ``tree``
            (or with ``tree`` as a structural prefix)
        is_leaf: Optional predicate; a subtree for which it returns True is
            handed to ``f`` whole instead of being traversed further

    Returns:
        PyTree with the same structure as ``tree`` whose leaves are the results of ``f``
    """

def map_with_path(f, tree, *rest, is_leaf=None) -> Any:
    """
    Apply a function leaf-wise, also passing each leaf's key path.

    Args:
        f: Function called as ``f(path, leaf, *rest_leaves)``. ``path`` is a
            tuple of key-entry objects (e.g. ``DictKey``, ``SequenceKey``,
            ``GetAttrKey``), not a tuple of plain strings or ints
        tree: Primary PyTree; its structure determines the structure of the result
        rest: Additional PyTrees with the same structure as ``tree``
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        PyTree with the same structure as ``tree`` whose leaves are the results of ``f``
    """

def reduce(function, tree, initializer=None, is_leaf=None) -> Any:
    """
    Fold all leaves of a PyTree into a single value.

    Leaves are visited in flattening order and combined left to right.

    Args:
        function: Binary function ``(accumulator, leaf) -> accumulator``
        tree: PyTree to reduce
        initializer: Optional starting value for the accumulator; when not
            supplied, the reduction starts from the first leaf
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        The final accumulator value
    """

def all(tree) -> bool:
    """
    Return True if every leaf of the tree is truthy.

    Each leaf is tested for truthiness individually, so array leaves with
    more than one element should be reduced first (e.g. with ``jnp.all``).

    Args:
        tree: PyTree whose leaves are checked

    Returns:
        True when all leaves are truthy (also for a tree with no leaves),
        False otherwise
    """

Usage examples:

# Apply function to all arrays in parameter tree
def init_weights(params):
    return jtree.map(lambda x: x * 0.01, params)

# Element-wise operations on multiple trees
def add_trees(tree1, tree2):
    """Add two PyTrees of identical structure leaf by leaf."""
    def _add(left, right):
        return left + right

    return jtree.map(_add, tree1, tree2)

# Compute total number of parameters
def count_params(params):
    """Return the total number of elements across all leaves of ``params``."""
    # Summing from 0 over the flattened leaves is the same left fold as a
    # reduce with initializer 0.
    return sum(leaf.size for leaf in jtree.leaves(params))

# Check if all gradients are finite
def all_finite(grads):
    """
    Return True if every element of every leaf in ``grads`` is finite.

    Args:
        grads: PyTree of arrays (or scalars)

    Returns:
        Python bool; True for a tree with no leaves
    """
    # jtree.all tests each leaf's truthiness, which is ambiguous (raises) for
    # arrays with more than one element, so collapse each leaf to a scalar
    # boolean first.
    return jtree.all(jtree.map(lambda g: jnp.all(jnp.isfinite(g)), grads))

# Apply different functions based on path
def scale_by_path(path, param):
    """
    Scale a leaf by 0.1 if its key path contains a ``'bias'`` key, else by 1.0.

    Args:
        path: Key path of the leaf, as passed by ``map_with_path``: a tuple of
            key-entry objects (dict entries expose the key as ``.key``).
            Plain strings in the tuple are accepted as well
        param: Leaf value to scale

    Returns:
        The scaled leaf
    """
    # Key paths hold objects such as DictKey(key='bias'), so a plain
    # `'bias' in path` test never matches; compare against the entry's key.
    is_bias = any(
        entry == 'bias' or getattr(entry, 'key', None) == 'bias'
        for entry in path
    )
    if is_bias:
        return param * 0.1  # Smaller learning rate for biases
    return param * 1.0

# NOTE: `gradients` is not defined in this snippet; it is assumed to be an
# existing PyTree of gradient arrays.
scaled_grads = jtree.map_with_path(scale_by_path, gradients)

Tree Structure Operations

Flatten trees into lists and reconstruct them, useful for interfacing with optimizers and other libraries.

def flatten(tree, is_leaf=None) -> tuple[list, Any]:
    """
    Flatten a PyTree into a list of leaves plus a tree definition.

    Dict entries are visited in sorted-key order, so the order of the
    returned leaves may differ from dict insertion order.

    Args:
        tree: PyTree to flatten
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        Tuple ``(leaves, treedef)``; passing both to ``unflatten`` rebuilds the tree
    """

def unflatten(treedef, leaves) -> Any:
    """
    Rebuild a PyTree from a tree definition and a sequence of leaves.

    Args:
        treedef: Tree definition, e.g. the second item returned by ``flatten()``
        leaves: Leaf values in the same order ``flatten()`` produces them;
            the number of leaves must match ``treedef``

    Returns:
        PyTree with the structure described by ``treedef`` holding ``leaves``
    """

def flatten_with_path(tree, is_leaf=None) -> tuple[list, Any]:
    """
    Flatten a PyTree, pairing every leaf with its key path.

    Args:
        tree: PyTree to flatten
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        Tuple ``(path_leaf_pairs, treedef)``: a list of ``(path, leaf)``
        pairs and the tree definition (the same kind of object that
        ``flatten`` returns as its second item)
    """

def leaves(tree, is_leaf=None) -> list:
    """
    Return the leaves of a PyTree as a flat list.

    Args:
        tree: PyTree to extract leaves from
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        List of leaf values, in the same order as ``flatten`` yields them
    """

def leaves_with_path(tree, is_leaf=None) -> list:
    """
    Return every leaf of a PyTree together with its key path.

    Args:
        tree: PyTree to extract leaves from
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        List of ``(path, leaf)`` tuples, where ``path`` is a tuple of
        key-entry objects (e.g. ``DictKey``, ``SequenceKey``)
    """

def structure(tree, is_leaf=None) -> Any:
    """
    Return the tree definition of a PyTree, discarding the leaf values.

    Args:
        tree: PyTree to get structure from
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        Tree definition (the same kind of object ``flatten`` returns as its
        second item), usable with ``unflatten`` and ``transpose``
    """

Usage examples:

# Flatten for use with scipy optimizers
params = {'w': jnp.array([1, 2]), 'b': jnp.array([3])}
flat_params, tree_def = jtree.flatten(params)
# Dict leaves come out in sorted-key order ('b' before 'w'), not insertion order
print(flat_params)  # [Array([3], dtype=int32), Array([1, 2], dtype=int32)]

# Reconstruct after optimization
# New leaves must follow the same flattened order: 'b' first, then 'w'
new_flat_params = [jnp.array([6]), jnp.array([4, 5])]
new_params = jtree.unflatten(tree_def, new_flat_params)
print(new_params)  # new_params['w'] is Array([4, 5]), new_params['b'] is Array([6])

# Get all parameter arrays (dict leaves are returned in sorted-key order)
all_arrays = jtree.leaves(params)

# Inspect structure with paths; each path is a tuple of key-entry objects
path_leaf_pairs = jtree.leaves_with_path(params)  
print(path_leaf_pairs)  # [((DictKey(key='b'),), Array([3], dtype=int32)), ((DictKey(key='w'),), Array([1, 2], dtype=int32))]

# Get structure for later use (e.g. to unflatten a new list of leaves)
structure_only = jtree.structure(params)

Tree Transformation and Manipulation

Advanced operations for tree manipulation and structural transformations.

def transpose(outer_treedef, inner_treedef, pytree_to_transpose) -> Any:
    """
    Swap the outer and inner levels of a nested PyTree.

    For example, a list of dicts (outer: list, inner: dict) becomes a dict
    of lists.

    Args:
        outer_treedef: Tree definition of the OUTER level of
            ``pytree_to_transpose`` as it is now (e.g. a list of N placeholders)
        inner_treedef: Tree definition of each INNER subtree of
            ``pytree_to_transpose`` as it is now
        pytree_to_transpose: Nested PyTree whose structure is
            ``outer_treedef`` composed with ``inner_treedef``

    Returns:
        PyTree with the inner structure outermost and the outer structure
        at each of its leaves
    """

Usage example:

# Transpose structure: list of dicts -> dict of lists
list_of_dicts = [
    {'a': 1, 'b': 2},
    {'a': 3, 'b': 4},
    {'a': 5, 'b': 6},
]

# Describe each nesting level on its own, using placeholder leaves.
# (None cannot be the placeholder: it is an empty node and contributes no leaves.)
outer_structure = jtree.structure([0] * len(list_of_dicts))  # List of 3 leaves
inner_structure = jtree.structure({'a': 0, 'b': 0})  # Dict with keys 'a' and 'b'

# Transpose to dict of lists: the outer treedef comes first, then the inner one
dict_of_lists = jtree.transpose(outer_structure, inner_structure, list_of_dicts)
print(dict_of_lists)  # {'a': [1, 3, 5], 'b': [2, 4, 6]}

Broadcasting and Advanced Operations

def broadcast(prefix_tree, full_tree, is_leaf=None) -> Any:
    """
    Broadcast a tree prefix into the full structure of another tree.

    Every leaf of ``prefix_tree`` is repeated for all leaves of the
    corresponding subtree of ``full_tree``. For example, broadcasting
    ``(1, 2)`` against ``(0, {'a': 0, 'b': 0})`` gives
    ``(1, {'a': 2, 'b': 2})``. No function is applied to the leaves.

    Args:
        prefix_tree: PyTree whose structure is a prefix of ``full_tree``
        full_tree: PyTree supplying the target structure; its leaf values
            are not used
        is_leaf: Optional predicate; a subtree for which it returns True is
            treated as a leaf

    Returns:
        PyTree with the structure of ``full_tree`` and leaf values taken
        from ``prefix_tree``
    """

Custom PyTree Types

Register custom classes as PyTree nodes:

import jax

# Register custom class as PyTree node
class MyContainer:
    """Minimal wrapper around a mapping, used to demonstrate custom PyTree nodes."""

    def __init__(self, data):
        # Kept by reference; no copy is made
        self.data = data

    def __repr__(self):
        return "MyContainer({})".format(self.data)

def container_flatten(container):
    """Split a container into its children (dict values) and static aux data (dict keys)."""
    mapping = container.data
    children = mapping.values()  # each value may itself be a PyTree
    aux_data = tuple(mapping.keys())  # hashable record of the key order
    return (children, aux_data)

def container_unflatten(aux_data, children):
    """Rebuild a MyContainer by pairing the stored keys with the given children."""
    rebuilt = {key: child for key, child in zip(aux_data, children)}
    return MyContainer(rebuilt)

# Register the PyTree node
# Arguments: the node class, its flatten function, its unflatten function
jax.tree_util.register_pytree_node(
    MyContainer,
    container_flatten,
    container_unflatten
)

# Now MyContainer works with tree operations
container = MyContainer({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
# map flattens the container, doubles each child array, then rebuilds a MyContainer
doubled = jtree.map(lambda x: x * 2, container)
print(doubled)  # MyContainer({'x': Array([2, 4], dtype=int32), 'y': Array([6, 8], dtype=int32)})

Common Usage Patterns

Neural Network Parameter Management

# Initialize network parameters as PyTree  
def init_mlp_params(layer_sizes, key):
    """
    Build the parameter PyTree for a multi-layer perceptron.

    Args:
        layer_sizes: Sequence of layer widths, input size first
        key: JAX PRNG key used to draw the weights

    Returns:
        Dict mapping ``'layer_<i>'`` to ``{'weights': ..., 'biases': ...}``
        with normally distributed weights scaled by 0.01 and zero biases
    """
    layer_keys = jax.random.split(key, len(layer_sizes) - 1)
    shapes = zip(layer_sizes[:-1], layer_sizes[1:])

    params = {}
    for idx, (layer_key, (fan_in, fan_out)) in enumerate(zip(layer_keys, shapes)):
        # Only the first sub-key is consumed; biases are deterministic zeros
        weight_key, _ = jax.random.split(layer_key)
        weights = jax.random.normal(weight_key, (fan_in, fan_out)) * 0.01
        params[f'layer_{idx}'] = {'weights': weights, 'biases': jnp.zeros(fan_out)}
    return params

# Apply gradients using tree operations
def update_params(params, grads, learning_rate):
    """Take one gradient-descent step: ``p - learning_rate * g`` for every leaf."""
    def _sgd_step(param, grad):
        return param - learning_rate * grad

    return jtree.map(_sgd_step, params, grads)

# Compute parameter statistics
def param_stats(params):
    """
    Summarize a parameter PyTree.

    Returns:
        Dict with ``'total_params'`` (sum of leaf sizes) and ``'norm'``
        (square root of the summed squared elements over all leaves)
    """
    element_count = 0
    squared_total = 0
    for leaf in jtree.leaves(params):
        element_count = element_count + leaf.size
        squared_total = squared_total + jnp.sum(leaf**2)
    return {'total_params': element_count, 'norm': jnp.sqrt(squared_total)}

Optimizer State Management

# Adam optimizer state as PyTree
def init_adam_state(params):
    """
    Create the initial Adam state for ``params``.

    Returns:
        Dict with zero-filled moment trees ``'m'`` and ``'v'`` shaped like
        ``params``, and a ``'step'`` counter starting at 0
    """
    first_moment = jtree.map(jnp.zeros_like, params)
    second_moment = jtree.map(jnp.zeros_like, params)
    return dict(m=first_moment, v=second_moment, step=0)

def adam_update(params, grads, state, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """
    Perform one Adam step on a parameter PyTree.

    Args:
        params: Current parameters
        grads: Gradients with the same structure as ``params``
        state: Dict with moment trees ``'m'``, ``'v'`` and integer ``'step'``
        learning_rate: Step size
        beta1: Decay rate of the first-moment average
        beta2: Decay rate of the second-moment average
        eps: Constant added to the denominator

    Returns:
        Tuple ``(new_params, new_state)``; ``new_state`` stores the
        uncorrected (biased) moments and the incremented step
    """
    t = state['step'] + 1

    def _ema_first(prev, g):
        return beta1 * prev + (1 - beta1) * g

    def _ema_second(prev, g):
        return beta2 * prev + (1 - beta2) * g**2

    # Biased moment estimates
    first = jtree.map(_ema_first, state['m'], grads)
    second = jtree.map(_ema_second, state['v'], grads)

    # Bias-corrected estimates
    first_hat = jtree.map(lambda leaf: leaf / (1 - beta1**t), first)
    second_hat = jtree.map(lambda leaf: leaf / (1 - beta2**t), second)

    def _apply(p, m_val, v_val):
        return p - learning_rate * m_val / (jnp.sqrt(v_val) + eps)

    updated = jtree.map(_apply, params, first_hat, second_hat)
    return updated, {'m': first, 'v': second, 'step': t}

Batch Processing

# Process batch of PyTrees
def process_batch(batch_trees):
    """
    Stack a sequence of identically structured PyTrees into one batched PyTree.

    Each leaf of the result is ``jnp.stack`` applied to the corresponding
    leaves of all input trees, so the batch index is the leading axis.
    """
    def _stack_leaves(*arrays):
        return jnp.stack(arrays)

    return jtree.map(_stack_leaves, *batch_trees)

# Example: batch of neural network inputs
# Three samples sharing one dict structure: an array leaf and a scalar leaf
batch_inputs = [
    {'image': jnp.ones((28, 28)), 'label': 5},
    {'image': jnp.zeros((28, 28)), 'label': 3},
    {'image': jnp.ones((28, 28)) * 0.5, 'label': 1}
]

# Leaves are stacked key by key, adding a leading batch axis of size 3
batched = process_batch(batch_inputs)
print(batched['image'].shape)  # (3, 28, 28)
print(batched['label'].shape)  # (3,)

Install with Tessl CLI

npx tessl i tessl/pypi-jax

docs

core-transformations.md

device-memory.md

experimental.md

index.md

low-level-ops.md

neural-networks.md

numpy-compatibility.md

random-numbers.md

scipy-compatibility.md

tree-operations.md

tile.json