CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation

73

1.92x
Overview
Eval results
Files

types.mddocs/

Type Definitions

Complete type system for JAX development including array types, shape specifications, device types, and other computational primitives used throughout the JAX ecosystem.

Capabilities

Array Types

Core array type definitions for JAX and NumPy arrays.

# Base array types
ArrayNumpy = np.ndarray
ArrayDevice = jax.Array
ArraySharded = jax.Array  # Backward compatibility alias
ArrayBatched = jax.Array  # Backward compatibility alias

# Generic array type combining JAX and NumPy arrays.
# NumPy scalar types (np.bool_, np.number) are included so that
# 0-d results of NumPy ops still satisfy this type.
Array = Union[
    ArrayDevice,
    ArrayBatched, 
    ArraySharded,
    ArrayNumpy,
    np.bool_,
    np.number
]

Tree Types

Type definitions for JAX pytrees containing arrays.

# Tree of generic arrays: either a leaf Array, or any nesting of
# iterables/mappings of such trees (the string forward references make
# the recursive definition legal).
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]

# Tree of JAX device arrays (same recursive structure, JAX-only leaves)
ArrayDeviceTree = Union[
    ArrayDevice, 
    Iterable['ArrayDeviceTree'], 
    Mapping[Any, 'ArrayDeviceTree']
]

# Tree of NumPy arrays (same recursive structure, NumPy-only leaves)
ArrayNumpyTree = Union[
    ArrayNumpy,
    Iterable['ArrayNumpyTree'], 
    Mapping[Any, 'ArrayNumpyTree']
]

Scalar and Numeric Types

Type definitions for scalar values and numeric data.

# Scalar types: plain Python numbers (note: bool is a subclass of int,
# so it also satisfies this type)
Scalar = Union[float, int]

# Combined numeric type including arrays and scalars
Numeric = Union[Array, Scalar]

Shape and Structure Types

Type definitions for array shapes and JAX structures.

# Shape type allowing flexible dimension specifications
# (each entry is an int or, via Any, an arbitrary placeholder object
# such as None or a named dimension)
Shape = Sequence[Union[int, Any]]

# JAX pytree definition type (the structure half of tree_flatten)
PyTreeDef = jax.tree_util.PyTreeDef

Device and Hardware Types

Type definitions for JAX devices and hardware.

# JAX device type
Device = jax.Device

# PRNG key type for random number generation
# NOTE(review): this covers keys as jax.Array; presumably both raw uint32
# keys and new-style typed keys from jax.random.key() — confirm against
# the chex/JAX versions in use.
PRNGKey = jax.Array

Data Type Definitions

Type definitions for array data types.

# Array dtype type (version-dependent)
ArrayDType = jax.typing.DTypeLike  # JAX 0.4.19+
# ArrayDType = Any  # Older JAX versions, where jax.typing is unavailable

Usage Examples

Type Annotations

import chex
import jax
import jax.numpy as jnp
from typing import Tuple, Optional

def process_batch(
    data: chex.Array,
    weights: chex.ArrayTree,
    batch_size: int
) -> chex.Array:
    """Apply a linear transform plus bias to a batch of data.

    Args:
        data: Input batch of shape (batch_size, features).
        weights: Mapping with a 'linear' matrix and a 'bias' vector.
        batch_size: Expected leading dimension of `data`.

    Returns:
        The transformed batch.
    """
    # Leading dimension must equal batch_size; feature dimension is
    # deliberately left flexible (None).
    chex.assert_shape(data, (batch_size, None))
    projected = jnp.dot(data, weights['linear'])
    return projected + weights['bias']

def compute_loss(
    predictions: chex.Array,
    targets: chex.Array
) -> chex.Scalar:
    """Return the mean squared error between predictions and targets."""
    squared_error = jnp.square(predictions - targets)
    return jnp.mean(squared_error)

def create_model_state(
    params: chex.ArrayTree,
    optimizer_state: chex.ArrayTree,
    step: int,
    rng_key: chex.PRNGKey
) -> dict:
    """Bundle training components into a single state dict.

    Args:
        params: Model parameter tree.
        optimizer_state: Optimizer state tree.
        step: Current training step counter.
        rng_key: PRNG key carried through training.

    Returns:
        Dict with keys 'params', 'opt_state', 'step' and 'rng'.
    """
    return dict(
        params=params,
        opt_state=optimizer_state,
        step=step,
        rng=rng_key,
    )

Shape Specifications

# Fix: Optional is used by process_sequence below but was not imported in
# this standalone snippet, which would raise NameError at definition time.
from typing import Callable, Optional

def linear_layer(
    inputs: chex.Array,  # Shape: (batch, input_dim)
    weights: chex.Array,  # Shape: (input_dim, output_dim)
    bias: chex.Array     # Shape: (output_dim,)
) -> chex.Array:        # Shape: (batch, output_dim)
    """Linear transformation layer.

    Raises:
        AssertionError: if any argument has an unexpected rank.
    """
    chex.assert_rank(inputs, 2)
    chex.assert_rank(weights, 2)
    chex.assert_rank(bias, 1)

    return jnp.dot(inputs, weights) + bias

# Flexible shape specifications
def process_sequence(
    sequence: chex.Array,  # Shape: (seq_len, batch, features)
    mask: Optional[chex.Array] = None  # Shape: (seq_len, batch) or None
) -> chex.Array:  # Shape: (batch, features)
    """Process variable-length sequences by mean-pooling over time.

    Args:
        sequence: Inputs of shape (seq_len, batch, features).
        mask: Optional multiplicative mask of shape (seq_len, batch).

    Returns:
        Sequence averaged over the time axis, shape (batch, features).
    """
    seq_len, batch_size, features = sequence.shape

    if mask is not None:
        chex.assert_shape(mask, (seq_len, batch_size))
        # Broadcast the (seq_len, batch) mask across the feature axis.
        sequence = sequence * mask[..., None]

    return jnp.mean(sequence, axis=0)  # Average over sequence length

Tree Type Usage

def initialize_model(
    key: chex.PRNGKey,
    input_shape: chex.Shape
) -> chex.ArrayTree:
    """Initialize model parameters as a tree structure.

    Args:
        key: PRNG key; split three ways for the independent initializers.
        input_shape: Input shape; only the last dimension is used.

    Returns:
        Parameter tree with 'encoder', 'decoder' and 'scale' entries.
    """
    # Split order matters: encoder, decoder, scale consume keys 0, 1, 2
    # respectively, matching the original initialization.
    enc_key, dec_key, scale_key = jax.random.split(key, 3)

    encoder = {
        'weights': jax.random.normal(enc_key, (input_shape[-1], 128)),
        'bias': jnp.zeros(128),
    }
    decoder = {
        'weights': jax.random.normal(dec_key, (128, 10)),
        'bias': jnp.zeros(10),
    }

    return {
        'encoder': encoder,
        'decoder': decoder,
        'scale': jax.random.uniform(scale_key, (), minval=0.5, maxval=1.5),
    }

def apply_model(
    params: chex.ArrayTree,
    inputs: chex.Array
) -> chex.Array:
    """Apply model with tree-structured parameters.

    Runs an encoder (affine + ReLU), a decoder (affine), then a learned
    global scale, exactly as initialized by the companion example.
    """
    encoder = params['encoder']
    decoder = params['decoder']

    # Encoder: affine transform followed by ReLU.
    hidden = jax.nn.relu(jnp.dot(inputs, encoder['weights']) + encoder['bias'])

    # Decoder: affine transform.
    outputs = jnp.dot(hidden, decoder['weights']) + decoder['bias']

    # Apply the learned global scale.
    return outputs * params['scale']

def tree_statistics(tree: chex.ArrayTree) -> dict:
    """Compute statistics over a tree of arrays.

    Args:
        tree: Pytree of arrays.

    Returns:
        A tree of identical structure where each array leaf is replaced by
        a dict with 'mean', 'std' and 'shape' entries.
    """

    def compute_stats(array: chex.Array) -> dict:
        return {
            'mean': jnp.mean(array),
            'std': jnp.std(array),
            'shape': array.shape
        }

    # jax.tree_map is deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(compute_stats, tree)

Device Type Usage

def distribute_computation(
    data: chex.Array,
    devices: list[chex.Device]
) -> chex.Array:
    """Distribute computation across multiple devices.

    Args:
        data: Batch of shape (batch, ...); batch must divide evenly by the
            number of devices.
        devices: Devices to shard across.

    Returns:
        Per-device sums of each shard, stacked along a leading device axis.
    """
    n_devices = len(devices)
    batch_size = data.shape[0]

    # pmap requires one equal-sized shard per device.
    chex.assert_is_divisible(batch_size, n_devices)

    # Reshape so axis 0 enumerates devices and axis 1 the per-device batch.
    shard_size = batch_size // n_devices
    shards = data.reshape(n_devices, shard_size, *data.shape[1:])

    # Each device reduces its own shard over the per-shard batch axis.
    def process_shard(shard):
        return jnp.sum(shard, axis=0)

    return jax.pmap(process_shard)(shards)

def check_device_placement(
    array: chex.Array,
    expected_device: chex.Device
) -> bool:
    """Check if array is placed on expected device.

    NumPy arrays have no device attribute and are treated as trivially
    correctly placed.
    """
    if not hasattr(array, 'device'):
        return True  # NumPy arrays don't have device placement
    return array.device == expected_device

Numeric Type Usage

def safe_divide(
    numerator: chex.Numeric,
    denominator: chex.Numeric,
    epsilon: float = 1e-8
) -> chex.Numeric:
    """Safely divide numeric values, guarding against near-zero denominators.

    Args:
        numerator: Scalar or array numerator.
        denominator: Scalar or array denominator.
        epsilon: Threshold below which |denominator| is replaced by epsilon.

    Returns:
        numerator / denominator with small-magnitude denominators clamped
        to epsilon.
    """
    # Handle both scalar and array inputs
    if isinstance(denominator, (int, float)):
        # Fix: match the array branch below — clamp any |denominator| below
        # epsilon, not just an exact zero. The old check (denominator == 0)
        # let tiny nonzero scalars through, defeating the safety guard.
        safe_denom = epsilon if abs(denominator) < epsilon else denominator
    else:
        safe_denom = jnp.where(
            jnp.abs(denominator) < epsilon,
            epsilon,
            denominator
        )

    return numerator / safe_denom

def normalize_features(
    features: chex.Array,
    axis: Optional[int] = None
) -> Tuple[chex.Array, chex.Scalar]:
    """Normalize features and return the normalization constant.

    Args:
        features: Array to normalize.
        axis: Axis along which to compute the norm; None normalizes globally.

    Returns:
        A tuple of (normalized features, squeezed norm). With axis=None the
        squeezed norm is a scalar; with an explicit axis it is an array.
    """
    # Fix: with keepdims=True this is an array, not a scalar, so the old
    # local annotation `norm: chex.Scalar` was incorrect.
    norm: chex.Array = jnp.linalg.norm(features, axis=axis, keepdims=True)

    # Reuses the epsilon-guarded division helper defined in this doc.
    normalized = safe_divide(features, norm)

    return normalized, jnp.squeeze(norm)

Generic Type Functions

from typing import TypeVar, Callable

# Bound TypeVar preserves the concrete tree type across the call.
T = TypeVar('T', bound=chex.ArrayTree)

def apply_tree_function(
    tree: T,
    fn: Callable[[chex.Array], chex.Array]
) -> T:
    """Apply function to all arrays in tree, preserving structure."""
    # jax.tree_map is deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(fn, tree)

def validate_tree_structure(
    tree1: chex.ArrayTree,
    tree2: chex.ArrayTree
) -> bool:
    """Validate that two trees have the same structure.

    Returns:
        True if mapping over both trees succeeds, False if their structures
        differ (tree_map raises TypeError/ValueError on mismatch).
    """
    try:
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the
        # stable API across JAX versions.
        jax.tree_util.tree_map(lambda x, y: None, tree1, tree2)
        return True
    except (TypeError, ValueError):
        return False

def convert_tree_dtype(
    tree: chex.ArrayTree,
    dtype: chex.ArrayDType
) -> chex.ArrayTree:
    """Convert all arrays in tree to specified dtype."""
    # jax.tree_map is deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(lambda x: x.astype(dtype), tree)

Type Compatibility

JAX Integration

All Chex types are designed for seamless integration with JAX:

  • Array types work with all JAX transformations
  • Tree types support JAX pytree operations
  • Shape types enable flexible dimension handling
  • Device types support multi-device computation

NumPy Compatibility

Chex types maintain NumPy compatibility:

  • Array types include NumPy arrays
  • Scalar types work with NumPy operations
  • Shape specifications support NumPy broadcasting

Version Compatibility

Type definitions adapt to JAX version differences:

  • ArrayDType uses JAX's DTypeLike when available
  • Backward compatibility aliases for deprecated types
  • Future-proof type specifications

Best Practices

Use Specific Types

# Good: Specific type information
def process_images(images: chex.Array) -> chex.Array:
    chex.assert_rank(images, 4)  # (batch, height, width, channels)
    return images

# Better: Include shape information in docstring
def process_images(images: chex.Array) -> chex.Array:
    """Process batch of images.
    
    Args:
        images: Array of shape (batch, height, width, channels)
    
    Returns:
        Processed images of same shape
    """
    # Fix: the original "Better" example had no body at all, so it silently
    # returned None, contradicting both its annotation and its docstring.
    chex.assert_rank(images, 4)
    return images

Combine with Assertions

def typed_function(
    data: chex.Array,
    weights: chex.ArrayTree
) -> chex.Array:
    # Runtime validation matches type annotations
    # NOTE(review): chex.assert_type normally expects a dtype-like argument
    # (e.g. float or jnp.int32); passing the chex.Array union here looks
    # unsupported — confirm against the chex assertions API.
    chex.assert_type(data, chex.Array)
    chex.assert_tree_has_only_ndarrays(weights)
    
    # NOTE(review): process_data is not defined in this example; presumably
    # a user-supplied processing function.
    return process_data(data, weights)

Document Shape Expectations

def attention_layer(
    query: chex.Array,    # (batch, seq_q, dim)
    key: chex.Array,      # (batch, seq_k, dim)  
    value: chex.Array     # (batch, seq_k, dim)
) -> chex.Array:         # (batch, seq_q, dim)
    """Multi-head attention with clear shape specifications."""
    # Illustrative stub: this example documents the shape contract only;
    # no attention computation is implemented.
    pass

Install with Tessl CLI

npx tessl i tessl/pypi-chex

docs

advanced.md

assertions.md

dataclasses.md

debugging.md

index.md

testing.md

types.md

tile.json