Comprehensive utilities library for JAX testing, debugging, and instrumentation
73
Complete type system for JAX development including array types, shape specifications, device types, and other computational primitives used throughout the JAX ecosystem.
Core array type definitions for JAX and NumPy arrays.
# Core array aliases. The sharded and batched names refer to the very same
# ``jax.Array`` class; they are kept only so older code keeps importing them.
ArrayNumpy = np.ndarray  # host-side NumPy array
ArrayDevice = jax.Array  # JAX array
ArraySharded = jax.Array  # Backward compatibility alias
ArrayBatched = jax.Array  # Backward compatibility alias

# Union of every array-like accepted by JAX or NumPy code, including NumPy
# scalar types (``np.bool_`` and ``np.number``). Duplicate ``jax.Array``
# members collapse inside ``Union``.
Array = Union[
    ArrayDevice, ArrayBatched, ArraySharded, ArrayNumpy, np.bool_, np.number
]

# Type definitions for JAX pytrees containing arrays.
# Recursive pytree aliases. A tree is either a leaf of the named array kind,
# an iterable of sub-trees, or a mapping from arbitrary keys to sub-trees.
# Self-references are written as string forward references.

# Tree of generic arrays
ArrayTree = Union[
    Array,
    Iterable['ArrayTree'],
    Mapping[Any, 'ArrayTree'],
]

# Tree of JAX device arrays
ArrayDeviceTree = Union[
    ArrayDevice, Iterable['ArrayDeviceTree'], Mapping[Any, 'ArrayDeviceTree']
]

# Tree of NumPy arrays
ArrayNumpyTree = Union[
    ArrayNumpy, Iterable['ArrayNumpyTree'], Mapping[Any, 'ArrayNumpyTree']
]

# Type definitions for scalar values and numeric data.
# Plain Python number.
Scalar = Union[float, int]
# Anything numeric: an ``Array`` (JAX/NumPy array or NumPy scalar) or a
# Python ``Scalar``.
Numeric = Union[Array, Scalar]

# Type definitions for array shapes and JAX structures.
# A shape is a sequence of per-dimension entries. The element type is
# ``Union[int, Any]``, so the annotation also admits non-int entries.
Shape = Sequence[Union[int, Any]]
# Structure descriptor of a JAX pytree.
PyTreeDef = jax.tree_util.PyTreeDef

# Type definitions for JAX devices and hardware.
Device = jax.Device  # a single JAX device handle
# PRNG keys are represented as ordinary JAX arrays, so this alias is simply
# ``jax.Array``.
PRNGKey = jax.Array

# Type definitions for array data types.
# Array dtype type (version-dependent). ``jax.typing.DTypeLike`` is used on
# JAX 0.4.19+. Older releases fall back to ``Any`` automatically, instead of
# requiring readers to swap a commented-out definition by hand.
try:
    ArrayDType = jax.typing.DTypeLike  # JAX 0.4.19+
except AttributeError:  # Older JAX versions without jax.typing.DTypeLike
    from typing import Any

    ArrayDType = Any
# This import was previously fused into the end of a comment, which meant
# ``chex`` was never actually imported.
import chex
import jax
import jax.numpy as jnp
from typing import Tuple, Optional
def process_batch(
    data: chex.Array,
    weights: chex.ArrayTree,
    batch_size: int
) -> chex.Array:
    """Apply the affine map ``data @ weights['linear'] + weights['bias']``.

    ``data`` is checked against the shape ``(batch_size, None)``. The ``None``
    entry leaves the feature dimension unconstrained.
    """
    expected_shape = (batch_size, None)  # flexible feature dimension
    chex.assert_shape(data, expected_shape)
    projected = jnp.dot(data, weights['linear'])
    return projected + weights['bias']
def compute_loss(
    predictions: chex.Array,
    targets: chex.Array
) -> chex.Array:
    """Compute the mean squared error between ``predictions`` and ``targets``.

    Args:
      predictions: Model outputs; must broadcast against ``targets``.
      targets: Reference values.

    Returns:
      A 0-d array holding the loss. ``jnp.mean`` yields a JAX array, not a
      Python ``float``/``int``, so ``chex.Array`` is the accurate annotation
      (``chex.Scalar`` is ``Union[float, int]``).
    """
    return jnp.mean((predictions - targets) ** 2)
def create_model_state(
    params: chex.ArrayTree,
    optimizer_state: chex.ArrayTree,
    step: int,
    rng_key: chex.PRNGKey
) -> dict:
    """Bundle the pieces of a training state into a single dictionary.

    Returns:
      A new dict with keys ``'params'``, ``'opt_state'``, ``'step'`` and
      ``'rng'``. Each key holds the corresponding argument unchanged.
    """
    state = dict(params=params, opt_state=optimizer_state)
    state.update(step=step, rng=rng_key)
    return state
from typing import Callable
def linear_layer(
    inputs: chex.Array,  # Shape: (batch, input_dim)
    weights: chex.Array,  # Shape: (input_dim, output_dim)
    bias: chex.Array  # Shape: (output_dim,)
) -> chex.Array:  # Shape: (batch, output_dim)
    """Linear transformation layer: ``inputs @ weights + bias``.

    The ranks of all three arguments are checked. The bias length is then
    cross-checked against ``weights``. A mis-sized rank-1 bias, such as shape
    ``(1,)``, would otherwise pass the rank check and silently broadcast
    across every output unit.
    """
    chex.assert_rank(inputs, 2)
    chex.assert_rank(weights, 2)
    chex.assert_rank(bias, 1)
    output_dim = weights.shape[1]
    chex.assert_shape(bias, (output_dim,))
    return jnp.dot(inputs, weights) + bias
# Flexible shape specifications
def process_sequence(
    sequence: chex.Array,  # Shape: (seq_len, batch, features)
    mask: Optional[chex.Array] = None  # Shape: (seq_len, batch) or None
) -> chex.Array:  # Shape: (batch, features)
    """Average variable-length sequences over the sequence axis.

    Args:
      sequence: Array unpacked as ``(seq_len, batch, features)``.
      mask: Optional ``(seq_len, batch)`` array that multiplies the sequence.
        1/True keeps a step and 0/False drops it; other values act as
        weights.

    Returns:
      ``(batch, features)`` averages. Without a mask, every step counts.
      With a mask, the masked sum is divided by the mask total for each batch
      element rather than by ``seq_len``, so padding no longer pulls the
      average toward zero. Columns whose mask sums to 0 yield 0.
    """
    seq_len, batch_size, _ = sequence.shape
    if mask is None:
        return jnp.mean(sequence, axis=0)  # Average over sequence length
    chex.assert_shape(mask, (seq_len, batch_size))
    total = jnp.sum(sequence * mask[..., None], axis=0)
    counts = jnp.sum(mask, axis=0)
    # Avoid 0/0 for fully-masked columns; with a 0/1 mask their sum is 0 too.
    safe_counts = jnp.where(counts == 0, 1, counts)
    return total / safe_counts[..., None]


def initialize_model(
    key: chex.PRNGKey,
    input_shape: chex.Shape,
    hidden_dim: int = 128,
    num_outputs: int = 10,
) -> chex.ArrayTree:
    """Initialize encoder/decoder parameters as a nested dict tree.

    Args:
      key: PRNG key, split into three subkeys.
      input_shape: Model input shape. Only its last entry (the feature
        dimension) is used.
      hidden_dim: Encoder output width (previously hard-coded to 128).
      num_outputs: Decoder output width (previously hard-coded to 10).

    Returns:
      ``{'encoder': {'weights', 'bias'}, 'decoder': {'weights', 'bias'},
      'scale'}``. Weights are drawn with ``jax.random.normal`` and biases are
      zeros. ``scale`` is a scalar drawn with ``jax.random.uniform`` between
      0.5 and 1.5.
    """
    enc_key, dec_key, scale_key = jax.random.split(key, 3)
    return {
        'encoder': {
            'weights': jax.random.normal(enc_key, (input_shape[-1], hidden_dim)),
            'bias': jnp.zeros(hidden_dim),
        },
        'decoder': {
            'weights': jax.random.normal(dec_key, (hidden_dim, num_outputs)),
            'bias': jnp.zeros(num_outputs),
        },
        'scale': jax.random.uniform(scale_key, (), minval=0.5, maxval=1.5),
    }
def apply_model(
    params: chex.ArrayTree,
    inputs: chex.Array
) -> chex.Array:
    """Run the encoder/decoder model described by ``params`` on ``inputs``.

    Computes ``relu(inputs @ W_enc + b_enc) @ W_dec + b_dec``, then multiplies
    the result by the global ``params['scale']``.
    """
    encoder = params['encoder']
    decoder = params['decoder']
    hidden = jax.nn.relu(jnp.dot(inputs, encoder['weights']) + encoder['bias'])
    decoded = jnp.dot(hidden, decoder['weights']) + decoder['bias']
    return decoded * params['scale']  # global scale
def tree_statistics(tree: chex.ArrayTree) -> dict:
    """Compute per-leaf statistics over a tree of arrays.

    Returns:
      A tree with the same structure as ``tree`` in which every leaf is
      replaced by a ``{'mean', 'std', 'shape'}`` dict.
    """
    def compute_stats(array: chex.Array) -> dict:
        return {
            'mean': jnp.mean(array),
            'std': jnp.std(array),
            # ``jnp.shape`` also works for scalar leaves that lack ``.shape``.
            'shape': jnp.shape(array),
        }

    # ``jax.tree_map`` is deprecated and gone from recent JAX releases;
    # ``jax.tree_util.tree_map`` is the long-standing stable spelling.
    return jax.tree_util.tree_map(compute_stats, tree)


def distribute_computation(
    data: chex.Array,
    devices: list[chex.Device]
) -> chex.Array:
    """Split ``data`` along axis 0 and sum each shard on its own device.

    Args:
      data: Array whose leading (batch) dimension is divisible by
        ``len(devices)``.
      devices: Devices to run on. Shard ``i`` is mapped to ``devices[i]``.

    Returns:
      Array of shape ``(len(devices), *data.shape[1:])``. Row ``i`` is the sum
      over the ``i``-th contiguous shard of ``data``.

    Raises:
      ValueError: if ``devices`` is empty.
    """
    n_devices = len(devices)
    if n_devices == 0:
        # Without this guard the divisibility check divides by zero.
        raise ValueError('distribute_computation requires at least one device.')
    batch_size = data.shape[0]

    # Ensure data can be evenly split
    chex.assert_is_divisible(batch_size, n_devices)

    # Split data across devices
    per_device_size = batch_size // n_devices
    split_data = data.reshape(n_devices, per_device_size, *data.shape[1:])

    def process_shard(shard):
        return jnp.sum(shard, axis=0)

    # Pin pmap to the caller's devices. Without ``devices=`` it runs on the
    # default local devices, ignoring the ``devices`` argument entirely.
    return jax.pmap(process_shard, devices=devices)(split_data)
def check_device_placement(
    array: chex.Array,
    expected_device: chex.Device
) -> bool:
    """Check whether ``array`` resides on ``expected_device``.

    JAX arrays are compared through ``jax.Array.devices()``, a set of the
    devices holding the array. An array sharded over several devices
    therefore does not count as placed on a single device. An attribute
    check on ``.device`` was unreliable for two reasons:
    - in some older JAX releases it is a method, so the comparison was
      always False;
    - NumPy 2 arrays expose ``.device == 'cpu'``, a string that never equals
      a JAX device.
    Non-JAX inputs have no JAX placement and always pass.
    """
    if isinstance(array, jax.Array):
        return array.devices() == {expected_device}
    return True  # NumPy arrays don't have JAX device placement


def safe_divide(
    numerator: chex.Numeric,
    denominator: chex.Numeric,
    epsilon: float = 1e-8
) -> chex.Numeric:
    """Divide ``numerator`` by ``denominator``, clamping tiny denominators.

    Any denominator whose magnitude is below ``epsilon`` is replaced by
    ``epsilon`` before dividing. Python ``int``/``float`` denominators now
    follow the same rule as arrays; previously only an exact ``0`` was
    adjusted for scalars, so values like ``1e-300`` slipped through.

    Note: the replacement is always ``+epsilon``, so tiny negative
    denominators change sign.
    """
    # Handle both scalar and array inputs with the same clamping rule.
    if isinstance(denominator, (int, float)):
        safe_denom = epsilon if abs(denominator) < epsilon else denominator
    else:
        safe_denom = jnp.where(
            jnp.abs(denominator) < epsilon,
            epsilon,
            denominator
        )
    return numerator / safe_denom
def normalize_features(
    features: chex.Array,
    axis: Optional[int] = None
) -> Tuple[chex.Array, chex.Array]:
    """Normalize ``features`` by their norm and return that norm.

    Args:
      features: Array to normalize.
      axis: Axis along which ``jnp.linalg.norm`` is taken. ``None`` takes the
        norm of the flattened array.

    Returns:
      ``(normalized, norm)``. ``norm`` is a 0-d array when ``axis`` is None;
      otherwise it is the norm with ``axis`` removed. Only the reduced axis
      is squeezed, so other size-1 dimensions (e.g. a batch of one) are kept.
      A plain ``jnp.squeeze(norm)`` would drop those dimensions too.
    """
    # keepdims=True lets the norm broadcast back against ``features``.
    norm = jnp.linalg.norm(features, axis=axis, keepdims=True)
    normalized = safe_divide(features, norm)
    return normalized, jnp.squeeze(norm, axis=axis)
from typing import TypeVar, Callable
T = TypeVar('T', bound=chex.ArrayTree)


def apply_tree_function(
    tree: T,
    fn: Callable[[chex.Array], chex.Array]
) -> T:
    """Apply ``fn`` to every leaf of ``tree``, preserving its structure.

    Uses ``jax.tree_util.tree_map``. The ``jax.tree_map`` alias is deprecated
    and no longer exists in recent JAX releases.
    """
    return jax.tree_util.tree_map(fn, tree)
def validate_tree_structure(
    tree1: chex.ArrayTree,
    tree2: chex.ArrayTree
) -> bool:
    """Return True iff ``tree1`` and ``tree2`` have identical pytree structure.

    The two ``PyTreeDef`` objects are compared directly. Mapping a
    two-argument function over both trees is not enough, because ``tree_map``
    only requires ``tree1`` to be a *prefix* of ``tree2``. For example,
    ``{'a': 1}`` and ``{'a': (1, 2)}`` would wrongly be reported as matching.
    """
    structure1 = jax.tree_util.tree_structure(tree1)
    structure2 = jax.tree_util.tree_structure(tree2)
    return structure1 == structure2
def convert_tree_dtype(
    tree: chex.ArrayTree,
    dtype: chex.ArrayDType
) -> chex.ArrayTree:
    """Cast every leaf of ``tree`` to ``dtype`` via the leaf's ``astype``.

    Leaves must provide ``astype``; JAX/NumPy arrays and NumPy scalars do.
    ``jax.tree_util.tree_map`` replaces the deprecated ``jax.tree_map``
    alias, which recent JAX releases removed.
    """
    return jax.tree_util.tree_map(lambda x: x.astype(dtype), tree)
# All Chex types are designed for seamless integration with JAX:
Chex types maintain NumPy compatibility:
Type definitions adapt to JAX version differences:
# Good: Specific type information
def process_images(images: chex.Array) -> chex.Array:
    """Validate that ``images`` is rank 4 and return it unchanged."""
    chex.assert_rank(images, 4)  # (batch, height, width, channels)
    validated = images
    return validated
# Better: Include shape information in docstring
def process_images(images: chex.Array) -> chex.Array:
    """Process batch of images.

    Args:
        images: Array of shape (batch, height, width, channels)

    Returns:
        Processed images of same shape
    """
    # With only a docstring this returned None despite the ``chex.Array``
    # annotation; enforce the documented rank and return the input.
    chex.assert_rank(images, 4)
    return images


def typed_function(
    data: chex.Array,
    weights: chex.ArrayTree
) -> chex.Array:
    """Validate inputs at runtime, then delegate to ``process_data``."""
    # Runtime validation matches type annotations. ``chex.assert_type``
    # compares dtypes (e.g. ``jnp.float32``) and cannot take the typing alias
    # ``chex.Array``. Checking that every leaf is an ndarray is the closest
    # runtime counterpart of the annotation.
    chex.assert_tree_has_only_ndarrays(data)
    chex.assert_tree_has_only_ndarrays(weights)
    # NOTE(review): ``process_data`` is not defined in this snippet.
    return process_data(data, weights)


def attention_layer(
    query: chex.Array,  # (batch, seq_q, dim)
    key: chex.Array,  # (batch, seq_k, dim)
    value: chex.Array  # (batch, seq_k, dim)
) -> chex.Array:  # (batch, seq_q, dim)
    """Scaled dot-product attention with clear shape specifications.

    Computes ``softmax(query @ key^T / sqrt(dim)) @ value`` per batch
    element. The documented shapes carry no head axis, so this is a single
    attention head; the stub previously returned None.

    Args:
      query: (batch, seq_q, dim)
      key: (batch, seq_k, dim)
      value: (batch, seq_k, dim)

    Returns:
      (batch, seq_q, dim) attended values.
    """
    chex.assert_rank([query, key, value], 3)
    chex.assert_equal_shape([key, value])
    scale = 1.0 / jnp.sqrt(query.shape[-1])
    scores = jnp.einsum('bqd,bkd->bqk', query, key) * scale
    weights = jax.nn.softmax(scores, axis=-1)  # normalize over seq_k
    return jnp.einsum('bqk,bkd->bqd', weights, value)
# Install with Tessl CLI
npx tessl i tessl/pypi-chex
evals
scenario-1
scenario-2
scenario-3
scenario-4
scenario-5
scenario-6
scenario-7
scenario-8
scenario-9
scenario-10