CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-optax

A gradient processing and optimization library in JAX

Pending
Overview
Eval results
Files

docs/tree-utilities.md

Tree Utilities

JAX PyTree manipulation utilities for working with nested parameter structures, exposed in the `optax.tree_utils` module. These functions provide efficient operations on tree-structured data, which is common in neural network parameters, gradients, and optimizer states.

Capabilities

Tree Manipulation

def tree_cast(tree, dtype):
    """Cast all leaves of ``tree`` to ``dtype``."""

def tree_cast_like(tree, target_tree):
    """Cast the leaves of ``tree`` to the dtypes of the matching leaves of ``target_tree``."""

def tree_dtype(tree):
    """Return the dtype of the leaves of ``tree`` as a single dtype."""

def tree_get(tree, key, default=None, filtering=None):
    """Return the value stored under the entry named ``key`` anywhere in ``tree``.

    Args:
      tree: pytree to search (typically an optimizer state).
      key: name of the entry to look up (a dict key, a NamedTuple field, ...).
        This is a name, not a path tuple.
      default: value returned when no entry named ``key`` exists.
      filtering: optional ``(path, value) -> bool`` predicate that further
        restricts which matches are considered.

    Raises:
      KeyError: if several entries named ``key`` are found.
    """

def tree_set(tree, filtering=None, /, **kwargs):
    """Return a copy of ``tree`` with the entries named in ``kwargs`` replaced.

    Entries are matched by name, e.g. ``tree_set(opt_state, learning_rate=0.1)``.
    ``filtering`` is an optional ``(path, value) -> bool`` predicate restricting
    which matches are replaced.
    """

def tree_get_all_with_path(tree, key, filtering=None):
    """Return every ``(path, value)`` pair in ``tree`` whose entry is named ``key``."""

Tree Arithmetic

def tree_add(tree_a, tree_b, *other_trees):
    """Element-wise sum of two (or more) trees with identical structure."""

def tree_sub(tree_a, tree_b):
    """Element-wise difference ``tree_a - tree_b``."""

def tree_mul(tree_a, tree_b):
    """Element-wise product ``tree_a * tree_b``."""

def tree_div(tree_a, tree_b):
    """Element-wise quotient ``tree_a / tree_b``."""

def tree_add_scale(tree, scalar, scaled_tree):
    """Compute ``tree + scalar * scaled_tree`` leaf by leaf."""

def tree_scale(scalar, tree):
    """Multiply every leaf of ``tree`` by ``scalar``.

    Note the argument order: the scalar comes first,
    e.g. ``tree_scale(0.01, grads)``.
    """

Tree Reductions and Shapes

def tree_sum(tree):
    """Sum of all elements across all leaves of ``tree``."""

def tree_max(tree):
    """Maximum element across all leaves of ``tree``."""

def tree_vdot(tree_a, tree_b):
    """Inner product of two trees: the sum of ``vdot`` over matching leaves."""

def tree_batch_shape(tree, shape=()):
    """Return ``tree`` with the leading batch dimensions ``shape`` prepended to each leaf's shape.

    This is not a reduction; it does not infer a batch shape from ``tree``.
    """

Tree Creation

def tree_zeros_like(tree, dtype=None):
    """Tree of zeros with the same structure and leaf shapes as ``tree``.

    Leaves keep their dtype unless ``dtype`` is given.
    """

def tree_ones_like(tree, dtype=None):
    """Tree of ones with the same structure and leaf shapes as ``tree``.

    Leaves keep their dtype unless ``dtype`` is given.
    """

def tree_full_like(tree, fill_value, dtype=None):
    """Tree with the structure and leaf shapes of ``tree``, filled with ``fill_value``."""

def tree_random_like(rng_key, target_tree, sampler=jax.random.normal, dtype=None):
    """Tree of random values with the structure and leaf shapes of ``target_tree``.

    The PRNG key comes first, e.g. ``tree_random_like(key, params)``. Values are
    drawn with ``sampler`` (``jax.random.normal`` by default).
    """

Complex Numbers and Selection

def tree_real(tree):
    """Return a tree holding the real part of every leaf of ``tree``."""

def tree_conj(tree):
    """Return a tree holding the complex conjugate of every leaf of ``tree``."""

def tree_where(condition, tree_a, tree_b):
    """Choose between ``tree_a`` and ``tree_b`` leaf by leaf, as ``jnp.where`` does.

    Not specific to complex numbers; works for any pair of matching trees.
    """

Usage Examples

import jax
import jax.numpy as jnp
import optax

# Example tree structure (neural network parameters)
params = {
    'dense1': {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros(5)},
    'dense2': {'weights': jnp.ones((5, 1)), 'bias': jnp.zeros(1)},
}

# Tree arithmetic operations. The helpers live in `optax.tree_utils`;
# note that `tree_scale` takes the scalar FIRST.
grads = optax.tree_utils.tree_zeros_like(params)
scaled_grads = optax.tree_utils.tree_scale(0.01, grads)
updated_params = optax.tree_utils.tree_sub(params, scaled_grads)

# Tree reductions. `tree_map_params` maps over the parameters stored inside an
# optimizer *state*, so a plain JAX tree map is used to count parameters.
total_params = optax.tree_utils.tree_sum(
    jax.tree_util.tree_map(lambda x: x.size, params)
)
param_norm = jnp.sqrt(optax.tree_utils.tree_vdot(params, params))

# Random initialization: the PRNG key comes FIRST.
key = jax.random.PRNGKey(42)
random_params = optax.tree_utils.tree_random_like(key, params)

# Reading and writing named entries, e.g. hyperparameters in an optimizer
# state. `tree_get`/`tree_set` match entries by name, not by a path tuple.
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=0.1)
opt_state = opt.init(params)
learning_rate = optax.tree_utils.tree_get(opt_state, 'learning_rate')
opt_state = optax.tree_utils.tree_set(
    opt_state, learning_rate=learning_rate * 0.9
)

Import

import optax.tree_utils as otu
# or
from optax.tree_utils import tree_add, tree_scale, tree_zeros_like

Install with Tessl CLI

npx tessl i tessl/pypi-optax

docs

advanced-optimizers.md

assignment.md

contrib.md

index.md

losses.md

monte-carlo.md

optimizers.md

perturbations.md

projections.md

schedules.md

second-order.md

transformations.md

tree-utilities.md

utilities.md

tile.json