Collection of common Python utilities for machine learning and scientific computing workflows.
—
Universal tree manipulation utilities compatible with TensorFlow nest, JAX tree_utils, DeepMind tree, and pure Python data structures. Provides a unified API for working with nested data structures across different ML frameworks.
Type definitions for tree structures.
Tree = Any # Nested data structure (dict, list, tuple, or custom)
LeafFn = Callable[[Any], bool] # Function to determine what constitutes a leaf
TreeDef = Any # Tree structure definition from flatten operations

Different backend implementations for tree operations.
jax: TreeAPI # JAX tree operations backend
optree: TreeAPI # Optree backend
tree: TreeAPI # DeepMind tree backend
nest: TreeAPI # TensorFlow nest backend
py: TreeAPI # Pure Python backend (default)

The py API provides the primary tree manipulation functions.
def map(
map_fn: Callable[..., Any],
*trees: Tree,
is_leaf: Optional[LeafFn] = None
) -> Tree:
"""
Apply function to all leaf values in tree structures.
Args:
map_fn: Function to apply to each leaf or set of leaves
*trees: Input tree structures (supports multiple trees)
is_leaf: Function to determine what constitutes a leaf
Returns:
Tree with function applied to all leaves
"""
def parallel_map(
map_fn: Callable[..., Any],
*trees: Tree,
num_threads: Optional[int] = None,
progress_bar: bool = False,
is_leaf: Optional[LeafFn] = None
) -> Tree:
"""
Apply function to all leaf values in parallel.
Args:
map_fn: Function to apply to each leaf or set of leaves
*trees: Input tree structures (supports multiple trees)
num_threads: Number of parallel threads to use
progress_bar: Whether to display a progress bar
is_leaf: Function to determine what constitutes a leaf
Returns:
Tree with function applied to all leaves in parallel
"""
def unzip(tree: Tree) -> Tree:
"""
Unzip a tree of tuples/lists into a tuple/list of trees.
Args:
tree: Tree containing tuples or lists
Returns:
Tuple/list of trees
"""
def stack(tree: Tree) -> Tree:
"""
Stack multiple trees into a single tree.
Args:
tree: Tree containing stackable elements
Returns:
Stacked tree structure
"""
def spec_like(
tree: Tree,
*,
ignore_other: bool = True
) -> Tree:
"""
Create a spec-like structure matching the tree shape.
Args:
tree: Input tree structure
ignore_other: Whether to ignore non-array types
Returns:
Spec structure matching input tree
"""
def copy(tree: Tree) -> Tree:
"""
Create a deep copy of the tree structure.
Args:
tree: Input tree structure
Returns:
Deep copy of the tree
"""
# Backend-specific methods (available via backend attribute)
def flatten(tree: Tree, *, is_leaf: Optional[LeafFn] = None) -> tuple[list, TreeDef]:
"""
Flatten a tree structure into a list of leaves and structure definition.
Args:
tree: Input tree structure
is_leaf: Function to determine what constitutes a leaf
Returns:
Tuple of (flat_sequence, tree_structure)
"""
def unflatten(structure: TreeDef, flat_sequence: list) -> Tree:
"""
Reconstruct a tree from flattened data and structure.
Args:
structure: Tree structure definition from flatten()
flat_sequence: Flattened list of leaf values
Returns:
Reconstructed tree structure
"""
def assert_same_structure(tree0: Tree, tree1: Tree) -> None:
"""
Assert that two trees have the same structure.
Args:
tree0: First tree
tree1: Second tree
Raises:
ValueError: If structures don't match
"""

Access to underlying backend implementations.
backend: ModuleType # Backend implementations module
tree_utils: ModuleType # Core tree utility functions module

from etils import etree
# Define a nested data structure
data = {
'params': {
'weights': [[1.0, 2.0], [3.0, 4.0]],
'bias': [0.1, 0.2]
},
'config': {
'learning_rate': 0.01,
'batch_size': 32
}
}
# Apply function to all numeric values
doubled = etree.py.map(lambda x: x * 2 if isinstance(x, (int, float)) else x, data)
# Result: All numeric values doubled
# Deep copy the structure
data_copy = etree.py.copy(data)

from etils import etree
# Multiple parameter sets
tree1 = {'a': [1, 2], 'b': {'c': 3}}
tree2 = {'a': [4, 5], 'b': {'c': 6}}
# Combine operations across trees
combined = etree.py.map(lambda x, y: x + y, tree1, tree2)
# Result: {'a': [5, 7], 'b': {'c': 9}}

from etils import etree
import jax
import tensorflow as tf
# JAX compatibility
jax_tree = {'params': jax.numpy.array([1, 2, 3])}
processed_jax = etree.jax.map(lambda x: x * 2, jax_tree)
# TensorFlow compatibility
tf_tree = {'weights': tf.constant([1.0, 2.0, 3.0])}
processed_tf = etree.nest.map(lambda x: x * 2, tf_tree)
# Pure Python (default)
py_tree = {'data': [1, 2, 3]}
processed_py = etree.py.map(lambda x: x * 2, py_tree)

from etils import etree
# Unzip paired data
paired_data = {
'train': [(x1, y1), (x2, y2), (x3, y3)],
'test': [(x4, y4), (x5, y5)]
}
x_data, y_data = etree.py.unzip(paired_data)
# Stack multiple examples
examples = [
{'features': [1, 2], 'label': 0},
{'features': [3, 4], 'label': 1},
{'features': [5, 6], 'label': 0}
]
batched = etree.py.stack(examples)
# Result: {'features': [[1,2], [3,4], [5,6]], 'label': [0, 1, 0]}
# Create spec structure
spec = etree.py.spec_like(data)
# Result: Structure matching data but with spec information

from etils import etree
import numpy as np
# Large data structure with expensive operations
large_data = {
'layer1': {'weights': np.random.rand(1000, 1000)},
'layer2': {'weights': np.random.rand(1000, 1000)},
'layer3': {'weights': np.random.rand(1000, 1000)}
}
# Expensive function (e.g., matrix operations)
def expensive_op(x):
return np.linalg.svd(x)[0] # SVD decomposition
# Apply in parallel for better performance
result = etree.py.parallel_map(expensive_op, large_data)

Install with Tessl CLI
npx tessl i tessl/pypi-etils