
tessl/pypi-etils

Collection of common python utils for machine learning and scientific computing workflows


Tree Manipulation (etree)

Universal tree manipulation utilities compatible with TensorFlow nest (tf.nest), JAX tree utilities (jax.tree_util), DeepMind tree, optree, and pure Python data structures. Provides a single API for working with nested data structures across ML frameworks.

Capabilities

Core Tree Type

Type aliases used across the tree API.

Tree = Any  # Nested data structure (dict, list, tuple, or custom)
LeafFn = Callable[[Any], bool]  # Function to determine what constitutes a leaf
TreeDef = Any  # Tree structure definition from flatten operations

Tree API Objects

Different backend implementations for tree operations.

jax: TreeAPI       # JAX tree operations backend
optree: TreeAPI    # Optree backend
tree: TreeAPI      # DeepMind tree backend  
nest: TreeAPI      # TensorFlow nest backend
py: TreeAPI        # Pure Python backend (default)

Core Tree Operations

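The py API provides the primary tree manipulation functions. Each function below is called through a TreeAPI object, for example etree.py.map(fn, tree); the other backends (etree.jax, etree.nest, etree.tree, etree.optree) are used the same way.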

def map(
    map_fn: Callable[..., Any], 
    *trees: Tree,
    is_leaf: Optional[LeafFn] = None
) -> Tree:
    """
    Apply function to all leaf values in tree structures.
    
    Args:
        map_fn: Function to apply to each leaf or set of leaves
        *trees: Input tree structures (supports multiple trees)
        is_leaf: Function to determine what constitutes a leaf
        
    Returns:
        Tree with function applied to all leaves
    """

def parallel_map(
    map_fn: Callable[..., Any], 
    *trees: Tree,
    num_threads: Optional[int] = None,
    progress_bar: bool = False,
    is_leaf: Optional[LeafFn] = None
) -> Tree:
    """
    Apply function to all leaf values in parallel.
    
    Args:
        map_fn: Function to apply to each leaf or set of leaves
        *trees: Input tree structures (supports multiple trees)
        num_threads: Number of parallel threads to use
        progress_bar: Whether to display a progress bar
        is_leaf: Function to determine what constitutes a leaf
        
    Returns:
        Tree with function applied to all leaves in parallel
    """

def unzip(tree: Tree) -> Tree:
    """
    Unzip a tree of tuples/lists into a tuple/list of trees.
    
    Args:
        tree: Tree containing tuples or lists
        
    Returns:
        Tuple/list of trees
    """

def stack(tree: Tree) -> Tree:
    """
    Stack multiple trees into a single tree.
    
    Args:
        tree: Tree containing stackable elements
        
    Returns:
        Stacked tree structure
    """

def spec_like(
    tree: Tree,
    *,
    ignore_other: bool = True
) -> Tree:
    """
    Create a spec-like structure matching the tree shape.
    
    Args:
        tree: Input tree structure
        ignore_other: Whether to ignore non-array types
        
    Returns:
        Spec structure matching input tree
    """

def copy(tree: Tree) -> Tree:
    """
    Create a deep copy of the tree structure.
    
    Args:
        tree: Input tree structure
        
    Returns:
        Deep copy of the tree
    """

# Backend-specific methods (available via backend attribute)
def flatten(tree: Tree, *, is_leaf: Optional[LeafFn] = None) -> tuple[list, TreeDef]:
    """
    Flatten a tree structure into a list of leaves and structure definition.
    
    Args:
        tree: Input tree structure
        is_leaf: Function to determine what constitutes a leaf
        
    Returns:
        Tuple of (flat_sequence, tree_structure)
    """

def unflatten(structure: TreeDef, flat_sequence: list) -> Tree:
    """
    Reconstruct a tree from flattened data and structure.
    
    Args:
        structure: Tree structure definition from flatten()
        flat_sequence: Flattened list of leaf values
        
    Returns:
        Reconstructed tree structure
    """

def assert_same_structure(tree0: Tree, tree1: Tree) -> None:
    """
    Assert that two trees have the same structure.
    
    Args:
        tree0: First tree
        tree1: Second tree
        
    Raises:
        ValueError: If structures don't match
    """

Backend Modules

Access to underlying backend implementations.

backend: ModuleType     # Backend implementations module
tree_utils: ModuleType  # Core tree utility functions module
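
The flatten and unflatten signatures listed under Core Tree Operations describe a round trip: leaves come out in a fixed order, and the structure definition is enough to rebuild the tree around replacement leaves. The following is a minimal pure-Python sketch of that contract, not the etils implementation. `flatten_tree`, `unflatten_tree`, and the `LEAF` marker are hypothetical names, and only plain dicts, lists, and tuples are handled.

```python
LEAF = object()  # hypothetical marker for a leaf position in a structure definition


def flatten_tree(tree):
    """Return (leaves, structure) for nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        keys = list(tree)  # keep insertion order
        children = [flatten_tree(tree[k]) for k in keys]
        kind = "dict"
    elif isinstance(tree, (list, tuple)):
        keys = None
        children = [flatten_tree(v) for v in tree]
        kind = "tuple" if isinstance(tree, tuple) else "list"
    else:
        return [tree], LEAF
    leaves = [leaf for child_leaves, _ in children for leaf in child_leaves]
    return leaves, (kind, keys, [structure for _, structure in children])


def unflatten_tree(structure, leaves):
    """Rebuild a tree from a structure definition and a flat list of leaves."""
    remaining = iter(leaves)

    def build(node):
        if node is LEAF:
            return next(remaining)
        kind, keys, children = node
        values = [build(child) for child in children]
        if kind == "dict":
            return dict(zip(keys, values))
        return tuple(values) if kind == "tuple" else values

    return build(structure)


tree = {"a": [1, 2], "b": {"c": (3, 4)}}
leaves, structure = flatten_tree(tree)  # leaves == [1, 2, 3, 4]
scaled = unflatten_tree(structure, [x * 10 for x in leaves])
# scaled == {"a": [10, 20], "b": {"c": (30, 40)}}
```

In this sketch, two trees have the same structure exactly when their structure definitions compare equal, which is the kind of check assert_same_structure performs before raising ValueError.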

Usage Examples

Basic Tree Operations

from etils import etree

# Define a nested data structure
data = {
    'params': {
        'weights': [[1.0, 2.0], [3.0, 4.0]],
        'bias': [0.1, 0.2]
    },
    'config': {
        'learning_rate': 0.01,
        'batch_size': 32
    }
}

# Double every numeric leaf; non-numeric leaves pass through unchanged
doubled = etree.py.map(lambda x: x * 2 if isinstance(x, (int, float)) else x, data)
# Result: same structure as data, with all numeric values doubled

# Deep copy the structure
data_copy = etree.py.copy(data)
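
The is_leaf argument in the map signature decides where traversal stops. As a rough illustration of that idea, here is a minimal pure-Python sketch; `tree_map` is a hypothetical helper written for this example, not the etils implementation, and it only understands plain dicts, lists, and tuples.

```python
def tree_map(fn, tree, is_leaf=None):
    """Apply fn to every leaf; is_leaf can mark a container as a leaf."""
    if is_leaf is not None and is_leaf(tree):
        return fn(tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v, is_leaf) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v, is_leaf) for v in tree)
    return fn(tree)


data = {"shape": (2, 3), "values": [1, 2]}

# Default traversal: tuples are containers, so each int is a leaf.
per_element = tree_map(lambda x: x * 2, data)


def describe(leaf):
    # Tuples arrive whole when is_leaf stops traversal at them.
    return len(leaf) if isinstance(leaf, tuple) else leaf * 2


# Tuples are treated as atomic leaves; list elements are still visited.
per_tuple = tree_map(describe, data, is_leaf=lambda x: isinstance(x, tuple))
```

Here per_element is {"shape": (4, 6), "values": [2, 4]}, while per_tuple is {"shape": 2, "values": [2, 4]} because the tuple is handed to the function as a single value.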

Working with Multiple Trees

from etils import etree

# Multiple parameter sets
tree1 = {'a': [1, 2], 'b': {'c': 3}}
tree2 = {'a': [4, 5], 'b': {'c': 6}}

# Combine operations across trees
combined = etree.py.map(lambda x, y: x + y, tree1, tree2)
# Result: {'a': [5, 7], 'b': {'c': 9}}
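
Mapping over several trees only makes sense when the trees line up leaf for leaf. The sketch below shows one way this pairing could work in pure Python; `map_trees` is a hypothetical name, the first tree is assumed to define the structure, and the actual library may validate structures differently.

```python
def map_trees(fn, *trees):
    """Apply fn to corresponding leaves; the first tree defines the structure."""
    first, rest = trees[0], trees[1:]
    if isinstance(first, dict):
        if any(not isinstance(t, dict) or t.keys() != first.keys() for t in rest):
            raise ValueError("dict keys differ between trees")
        return {k: map_trees(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        if any(type(t) is not type(first) or len(t) != len(first) for t in rest):
            raise ValueError("sequence type or length differs between trees")
        return type(first)(map_trees(fn, *items) for items in zip(*trees))
    return fn(*trees)  # leaf position: one argument per input tree


tree1 = {"a": [1, 2], "b": {"c": 3}}
tree2 = {"a": [4, 5], "b": {"c": 6}}
combined = map_trees(lambda x, y: x + y, tree1, tree2)
# combined == {"a": [5, 7], "b": {"c": 9}}
```

Failing early on a mismatch, rather than silently truncating with zip, keeps a shape bug in one tree from producing a plausible-looking but wrong result.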

Framework Compatibility

from etils import etree
import jax.numpy as jnp
import tensorflow as tf

# JAX compatibility
jax_tree = {'params': jnp.array([1, 2, 3])}
processed_jax = etree.jax.map(lambda x: x * 2, jax_tree)

# TensorFlow compatibility  
tf_tree = {'weights': tf.constant([1.0, 2.0, 3.0])}
processed_tf = etree.nest.map(lambda x: x * 2, tf_tree)

# Pure Python (default)
py_tree = {'data': [1, 2, 3]}
processed_py = etree.py.map(lambda x: x * 2, py_tree)
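
Each non-default backend depends on a separate framework being installed. A small check like the one below can report which backends are usable in the current environment before any of them is touched. This is a sketch under assumptions: `BACKEND_REQUIREMENTS` and `available_backends` are hypothetical names, and the mapping from etree attribute to importable module reflects the usual import names of these packages rather than anything etils exposes.

```python
import importlib.util

# etree attribute -> top-level module that backend relies on (assumed mapping).
BACKEND_REQUIREMENTS = {
    "jax": "jax",
    "optree": "optree",
    "tree": "tree",        # the dm-tree package is imported as `tree`
    "nest": "tensorflow",
    "py": None,            # pure Python, no extra dependency
}


def available_backends():
    """Return the backend names whose required module can be found."""
    return [
        name
        for name, module in BACKEND_REQUIREMENTS.items()
        if module is None or importlib.util.find_spec(module) is not None
    ]


print(available_backends())
```

The check uses find_spec so that heavy frameworks such as TensorFlow are located without being imported.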

Advanced Tree Operations

from etils import etree

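# Unzip paired data (x1..x5 and y1..y5 are placeholder samples)
x1, x2, x3, x4, x5 = 1.0, 2.0, 3.0, 4.0, 5.0
y1, y2, y3, y4, y5 = 0, 1, 0, 1, 0
paired_data = {
    'train': [(x1, y1), (x2, y2), (x3, y3)],
    'test': [(x4, y4), (x5, y5)]
}
x_data, y_data = etree.py.unzip(paired_data)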

# Stack multiple examples
examples = [
    {'features': [1, 2], 'label': 0},
    {'features': [3, 4], 'label': 1},
    {'features': [5, 6], 'label': 0}
]
batched = etree.py.stack(examples)
# Result: {'features': [[1,2], [3,4], [5,6]], 'label': [0, 1, 0]}

# Create spec structure (data is the nested dict from Basic Tree Operations)
spec = etree.py.spec_like(data)
# Result: Structure matching data but with spec information
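
Stacking and unzipping are inverse views of the same data: a list of records versus one record of lists. The sketch below illustrates that relationship for flat dicts only. `stack_records` and `unstack_records` are hypothetical helpers, each top-level value is treated as a leaf, and the library itself may stack into arrays or handle deeper nesting.

```python
def stack_records(records):
    """Turn a list of same-keyed dicts into a dict of lists."""
    keys = list(records[0])
    return {k: [record[k] for record in records] for k in keys}


def unstack_records(batched):
    """Inverse of stack_records: a dict of equal-length lists to a list of dicts."""
    keys = list(batched)
    rows = zip(*(batched[k] for k in keys))
    return [dict(zip(keys, row)) for row in rows]


examples = [
    {"features": [1, 2], "label": 0},
    {"features": [3, 4], "label": 1},
    {"features": [5, 6], "label": 0},
]
batched = stack_records(examples)
# batched == {"features": [[1, 2], [3, 4], [5, 6]], "label": [0, 1, 0]}
restored = unstack_records(batched)  # equal to examples again
```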

Parallel Processing

from etils import etree
import numpy as np

# Large data structure with expensive operations
large_data = {
    'layer1': {'weights': np.random.rand(1000, 1000)},
    'layer2': {'weights': np.random.rand(1000, 1000)},
    'layer3': {'weights': np.random.rand(1000, 1000)}
}

# Expensive function (e.g., matrix operations)
def expensive_op(x):
    return np.linalg.svd(x)[0]  # Left singular vectors (U) of the SVD

# Apply in parallel for better performance
result = etree.py.parallel_map(expensive_op, large_data)
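
To make the idea behind a thread-based parallel map concrete, the sketch below submits one task per leaf to a thread pool and then rebuilds the tree from the results. It is an illustration under assumptions, not the etils implementation: `parallel_tree_map` is a hypothetical name, the num_threads parameter simply mirrors the signature documented above, and only plain dicts, lists, and tuples are traversed.

```python
from concurrent.futures import ThreadPoolExecutor


def parallel_tree_map(fn, tree, num_threads=None):
    """Run fn on every leaf in a thread pool and return a tree of results."""
    with ThreadPoolExecutor(max_workers=num_threads) as pool:

        def submit(node):
            # Mirror the structure, replacing each leaf with a Future.
            if isinstance(node, dict):
                return {k: submit(v) for k, v in node.items()}
            if isinstance(node, (list, tuple)):
                return type(node)(submit(v) for v in node)
            return pool.submit(fn, node)

        def resolve(node):
            # Walk the tree of Futures and wait for each result.
            if isinstance(node, dict):
                return {k: resolve(v) for k, v in node.items()}
            if isinstance(node, (list, tuple)):
                return type(node)(resolve(v) for v in node)
            return node.result()

        return resolve(submit(tree))


squares = parallel_tree_map(lambda x: x * x, {"a": [1, 2], "b": {"c": 3}}, num_threads=2)
# squares == {"a": [1, 4], "b": {"c": 9}}
```

Threads pay off mainly when the mapped function releases the GIL, as NumPy linear algebra routines and I/O calls typically do; for pure-Python work on small leaves, a plain map is likely to be just as fast.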

Install with Tessl CLI

npx tessl i tessl/pypi-etils
