Optax: a gradient-processing and optimization library for JAX
—
JAX PyTree manipulation utilities for working with nested parameter structures. These functions provide efficient operations on tree-structured data, such as neural network parameters and gradients.
def tree_cast(tree, dtype):
"""Cast all leaves of a tree to specified dtype."""
def tree_cast_like(tree, target_tree):
"""Cast tree leaves to match dtypes of target tree."""
def tree_dtype(tree):
"""Get dtype information for tree leaves."""
def tree_get(tree, path):
"""Get value at specified path in tree."""
def tree_set(tree, path, value):
"""Set value at specified path in tree."""
def tree_get_all_with_path(tree):
"""Get all (path, value) pairs from tree."""
def tree_add(tree_a, tree_b):
"""Element-wise addition of two trees."""
def tree_sub(tree_a, tree_b):
"""Element-wise subtraction of two trees."""
def tree_mul(tree_a, tree_b):
"""Element-wise multiplication of two trees."""
def tree_div(tree_a, tree_b):
"""Element-wise division of two trees."""
def tree_add_scale(tree, scalar, scaled_tree):
"""Compute tree + scalar * scaled_tree."""
def tree_scale(tree, scalar):
"""Scale all leaves of tree by scalar."""
def tree_sum(tree):
"""Sum all leaves in tree."""
def tree_max(tree):
"""Maximum value across all leaves."""
def tree_vdot(tree_a, tree_b):
"""Vector dot product of flattened trees."""
def tree_batch_shape(tree):
"""Get batch shape from tree structure."""
def tree_zeros_like(tree):
"""Create tree of zeros with same structure."""
def tree_ones_like(tree):
"""Create tree of ones with same structure."""
def tree_full_like(tree, fill_value):
"""Create tree filled with specified value."""
def tree_random_like(tree, key):
"""Create tree of random values with same structure."""
def tree_real(tree):
"""Extract real parts of complex tree."""
def tree_conj(tree):
"""Complex conjugate of tree."""
def tree_where(condition, tree_a, tree_b):
"""Element-wise selection between trees."""
import optax
import jax.numpy as jnp
import jax
# Example tree structure (neural network parameters): a nested dict whose
# leaves are jax arrays.
params = {
    'dense1': {'weights': jnp.ones((10, 5)), 'bias': jnp.zeros(5)},
    'dense2': {'weights': jnp.ones((5, 1)), 'bias': jnp.zeros(1)},
}
# NOTE(review): argument orders below follow the signature listing above;
# confirm them against the installed optax version before copying.

# Tree arithmetic operations.
# Placeholder gradients: all zeros, same structure as ``params``. In real
# training these would come from e.g. ``jax.grad``.
grads = optax.tree.tree_zeros_like(params)
scaled_grads = optax.tree.tree_scale(grads, 0.01)
# One SGD-style step, params - 0.01 * grads (a no-op here since grads are 0).
updated_params = optax.tree.tree_sub(params, scaled_grads)

# Tree reductions.
# Count parameters: map every leaf to its element count, then sum the counts.
# ``tree_map_params`` is not part of the listed API (in optax it maps over the
# parameter-shaped parts of an optimizer state, with a different signature),
# so a plain JAX tree map is the right tool for a bare params tree.
total_params = optax.tree.tree_sum(
    jax.tree_util.tree_map(lambda x: x.size, params)
)
# Global L2 norm of all parameters: sqrt(<params, params>).
param_norm = jnp.sqrt(optax.tree.tree_vdot(params, params))

# Random initialization with the same structure as ``params``.
key = jax.random.PRNGKey(42)
random_params = optax.tree.tree_random_like(params, key)

# Working with paths: read one leaf, then build ``new_params`` with that
# leaf scaled by 0.9.
weight_path = ('dense1', 'weights')
weights = optax.tree.tree_get(params, weight_path)
new_params = optax.tree.tree_set(params, weight_path, weights * 0.9)

# Alternative import styles:
import optax.tree
# or
from optax.tree import tree_add, tree_scale, tree_zeros_like
Install with Tessl CLI:
npx tessl i tessl/pypi-optax