
tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation


Advanced Features

Specialized utilities including backend restriction, dimension mapping, jittable assertions, and deprecation management for advanced JAX development scenarios.

Capabilities

Backend Restriction

Context manager for controlling JAX backend compilation and device usage.

def restrict_backends(*, allowed=None, forbidden=None):
    """
    Context manager that prevents JAX compilation for specified backends.
    
    Useful for ensuring code runs only on intended devices or catching
    accidental compilation on restricted hardware.
    
    Parameters:
    - allowed: Sequence of allowed backend platform names (e.g., ['cpu', 'gpu'])
    - forbidden: Sequence of forbidden backend platform names
    
    Yields:
    - Context where compilation for forbidden platforms raises RestrictedBackendError
    
    Raises:
    - ValueError: If neither allowed nor forbidden specified, or if conflicts exist
    - RestrictedBackendError: If compilation attempted on restricted backend
    """

class RestrictedBackendError(RuntimeError):
    """
    Exception raised when compilation attempted on restricted backend.
    """
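The mechanism can be pictured as a context manager that temporarily intercepts a compile step and checks the target platform. The following is a simplified pure-Python sketch of that pattern, not chex's actual implementation (real chex patches JAX's internal compile hook; the yielded `check` callable here is a hypothetical stand-in):

```python
import contextlib


class RestrictedBackendError(RuntimeError):
    """Raised when compilation is attempted on a restricted backend."""


@contextlib.contextmanager
def restrict_backends_sketch(*, allowed=None, forbidden=None):
    # Mirror the documented validation rules.
    if allowed is None and forbidden is None:
        raise ValueError("Specify at least one of `allowed` or `forbidden`.")
    if allowed and forbidden and set(allowed) & set(forbidden):
        raise ValueError("`allowed` and `forbidden` must not overlap.")

    def check(platform):
        # Called in place of an intercepted compile step.
        if forbidden and platform in forbidden:
            raise RestrictedBackendError(f"Compilation for {platform!r} is forbidden.")
        if allowed and platform not in allowed:
            raise RestrictedBackendError(f"Only {allowed} are allowed, got {platform!r}.")

    yield check


# Usage: the checker stands in for a compile call inside the context.
with restrict_backends_sketch(allowed=['cpu']) as check:
    check('cpu')  # permitted
    try:
        check('gpu')
        blocked = False
    except RestrictedBackendError:
        blocked = True
```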

Dimension Mapping

Utility class for managing named dimensions and shape specifications.

class Dimensions:
    """
    Lightweight utility that maps strings to shape tuples.
    
    Enables readable shape specifications using named dimensions
    and supports dimension arithmetic and wildcard dimensions.
    
    Examples:
    >>> dims = chex.Dimensions(B=3, T=5, N=7)
    >>> dims['NBT']  # (7, 3, 5)
    >>> dims['(BT)N']  # (15, 7) - flattened dimensions
    >>> dims['BT*']  # (3, 5, None) - wildcard dimension
    """
    
    def __init__(self, **kwargs):
        """
        Initialize dimensions with named size mappings.
        
        Parameters:
        - **kwargs: Dimension name to size mappings (e.g., B=32, T=100)
        """
    
    def __getitem__(self, key):
        """
        Get shape tuple for dimension string specification.
        
        Parameters:
        - key: String specifying dimensions (e.g., 'BTC', '(BT)C', 'BT*')
        
        Returns:
        - Tuple of integers and/or None for wildcard dimensions
        """
    
    def __setitem__(self, key, value):
        """
        Set dimension sizes from shape tuple.
        
        Parameters:
        - key: String specifying dimensions
        - value: Shape tuple to assign to dimensions
        """
    
    def size(self, key):
        """
        Get total size (product) of specified dimensions.
        
        Parameters:
        - key: String specifying dimensions
        
        Returns:
        - Total number of elements in the specified shape
        """
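The core lookup logic can be illustrated with a toy pure-Python class: uppercase letters map to named sizes and `*` is a wildcard. This is a simplified sketch for intuition only; real `chex.Dimensions` additionally supports flattened groups such as `'(BT)N'` and other conveniences:

```python
class DimensionsSketch:
    """Toy name-to-size mapping illustrating the Dimensions lookup pattern."""

    def __init__(self, **sizes):
        self._sizes = dict(sizes)

    def __getitem__(self, key):
        # '*' is a wildcard (any size); letters look up named sizes.
        return tuple(None if ch == '*' else self._sizes[ch] for ch in key)

    def __setitem__(self, key, shape):
        # Bind each named dimension to the corresponding shape entry.
        for ch, n in zip(key, shape):
            self._sizes[ch] = n

    def size(self, key):
        # Product of the named dimension sizes.
        total = 1
        for ch in key:
            total *= self._sizes[ch]
        return total


dims = DimensionsSketch(B=3, T=5, N=7)
assert dims['NBT'] == (7, 3, 5)
assert dims['BT*'] == (3, 5, None)   # wildcard dimension
assert dims.size('BT') == 15
dims['B'] = (8,)                     # rebind from a shape tuple
assert dims['B'] == (8,)
```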

Jittable Assertions

Advanced assertion system that works inside jitted functions using JAX checkify.

def chexify(
    fn, 
    async_check=True,
    errors=ChexifyChecks.user
):
    """
    Enable Chex value assertions inside jitted functions.
    
    Wraps function to enable runtime assertions that work with JAX transformations
    by using JAX's checkify system for delayed error checking.
    
    Parameters:
    - fn: Function to wrap with jittable assertions
    - async_check: Whether to check errors asynchronously  
    - errors: Set of error categories to check (from ChexifyChecks)
    
    Returns:
    - Wrapped function that supports Chex assertions inside jit
    """

def with_jittable_assertions(fn):
    """
    Decorator for enabling jittable assertions in a function.
    
    Equivalent to chexify(fn) but as a decorator.
    
    Parameters:
    - fn: Function to decorate
    
    Returns:
    - Function with jittable assertions enabled
    """

def block_until_chexify_assertions_complete():
    """
    Wait for all asynchronous assertion checks to complete.
    
    Should be called after computations that use chexify to ensure
    all assertion errors are properly surfaced.
    """

class ChexifyChecks:
    """
    Collection of checkify error categories for jittable assertions.
    
    Attributes:
    - user: User-defined checks (Chex assertions)
    - nan: NaN detection checks
    - index: Array indexing checks  
    - div: Division by zero checks
    - float: Floating point error checks
    - automatic: Automatically enabled checks
    - all: All available checks
    """
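The check categories behave like sets, so they compose with the `|` operator (as in `ChexifyChecks.user | ChexifyChecks.nan` below). A toy stand-in, using plain frozensets in place of the real checkify error types:

```python
# Hypothetical stand-in values: the real attributes hold checkify error
# types, but the composition semantics are ordinary set union.
user = frozenset({'user_check'})
nan = frozenset({'nan_check'})
div = frozenset({'div_check'})

selected = user | nan            # user assertions plus NaN detection
everything = user | nan | div    # analogous to an "all checks" set

assert 'nan_check' in selected
assert selected < everything     # a strict subset of the full set
```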

Deprecation Management

Utilities for managing deprecated functions and warning users about API changes.

def warn_deprecated_function(fun, replacement=None):
    """
    Decorator to mark a function as deprecated.
    
    Emits DeprecationWarning when the decorated function is called.
    
    Parameters:
    - fun: Function to mark as deprecated
    - replacement: Optional name of replacement function
    
    Returns:
    - Wrapped function that emits deprecation warning
    """

def create_deprecated_function_alias(fun, new_name, deprecated_alias):
    """
    Create a deprecated alias for a function.
    
    Creates a new function that emits deprecation warning and delegates
    to the original function.
    
    Parameters:
    - fun: Original function
    - new_name: Current name of the function
    - deprecated_alias: Deprecated alias name
    
    Returns:
    - Deprecated alias function
    """

def warn_only_n_pos_args_in_future(fun, n):
    """
    Warn if more than n positional arguments are passed.
    
    Helps transition functions to keyword-only arguments by warning
    when too many positional arguments are used.
    
    Parameters:
    - fun: Function to wrap
    - n: Maximum number of allowed positional arguments
    
    Returns:
    - Wrapped function that warns about excess positional arguments
    """

def warn_keyword_args_only_in_future(fun):
    """
    Warn if any positional arguments are passed (keyword-only transition).
    
    Equivalent to warn_only_n_pos_args_in_future(fun, 0).
    
    Parameters:
    - fun: Function to wrap
    
    Returns:
    - Wrapped function that warns about positional arguments
    """
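Since these decorators take the target function as their first positional argument, they are applied with `functools.partial` when options are needed. A simplified sketch of how such a decorator can be built (an illustration of the pattern, not chex's implementation):

```python
import functools
import warnings


def warn_deprecated_sketch(fun, replacement=None):
    """Simplified deprecation decorator: warn, then delegate."""
    message = f"{fun.__name__} is deprecated"
    if replacement:
        message += f"; use {replacement} instead"

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return fun(*args, **kwargs)

    return wrapper


# Options are bound with functools.partial, leaving the function slot open.
@functools.partial(warn_deprecated_sketch, replacement='new_add_one')
def old_add_one(x):
    return x + 1


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    result = old_add_one(41)

assert result == 42
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```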

Usage Examples

Backend Restriction

import chex
import jax
import jax.numpy as jnp

# Ensure computation only runs on CPU
with chex.restrict_backends(allowed=['cpu']):
    @jax.jit
    def cpu_only_computation(x):
        return x ** 2
    
    result = cpu_only_computation(jnp.array([1, 2, 3]))
    # Works fine - compiles for CPU

# Prevent accidental GPU usage
with chex.restrict_backends(forbidden=['gpu', 'tpu']):
    def gpu_computation(x):
        return x + 1

    try:
        # Force compilation for the GPU backend
        with jax.default_device(jax.devices('gpu')[0]):
            jax.jit(gpu_computation)(jnp.array([1]))
    except chex.RestrictedBackendError:
        print("GPU compilation blocked as expected")

# Restrict during specific phases
def training_phase(model_fn, data):
    # Ensure training only uses CPUs (e.g., for memory reasons)
    with chex.restrict_backends(allowed=['cpu']):
        return model_fn(data)

def inference_phase(model_fn, data):
    # Allow inference on any available device
    return model_fn(data)

Dimension Mapping

import chex
import jax.numpy as jnp

# Create dimension mapping for transformer model
dims = chex.Dimensions(
    B=32,    # Batch size
    T=512,   # Sequence length  
    D=768,   # Model dimension
    H=12,    # Number of heads
    V=50000  # Vocabulary size
)

# Use dimensions for shape assertions
def transformer_layer(
    inputs,      # Shape: (B, T, D)
    weights_qkv, # Shape: (D, 3*D) 
    weights_out  # Shape: (D, D)
):
    # Validate input shapes using dimension names
    chex.assert_shape(inputs, dims['BTD'])
    (d,) = dims['D']  # unpack the single-element shape tuple
    chex.assert_shape(weights_qkv, (d, 3 * d))
    chex.assert_shape(weights_out, dims['DD'])
    
    # Compute attention
    batch_size, seq_len, model_dim = inputs.shape
    
    # Query, Key, Value projections
    (num_heads,) = dims['H']
    qkv = jnp.dot(inputs, weights_qkv)  # (B, T, 3*D)
    qkv = qkv.reshape(batch_size, seq_len, 3, num_heads, model_dim // num_heads)
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    
    # Multi-head attention computation...
    # Output shape should be (B, T, D)
    output = jnp.dot(attention_output, weights_out)
    
    chex.assert_shape(output, dims['BTD'])
    return output

# Dynamic dimension updates
def process_variable_batch(data):
    # Update batch dimension from the actual data (values are shape tuples)
    dims['B'] = data.shape[:1]
    
    # Use updated dimensions
    chex.assert_shape(data, dims['BTD'])
    return data

# Flattened dimensions for linear layers
def create_classifier_weights():
    # Flatten sequence and model dimensions
    input_size = dims.size('TD')   # T * D = 512 * 768
    output_size = dims.size('V')   # Vocabulary size
    
    return jnp.ones((input_size, output_size))

# Wildcard dimensions
def flexible_attention(queries, keys, values):
    # Allow any sequence length but fixed model dimension
    chex.assert_shape(queries, dims['B*D'])  # (B, any_seq_len, D)
    chex.assert_shape(keys, dims['B*D'])     # (B, any_seq_len, D) 
    chex.assert_shape(values, dims['B*D'])   # (B, any_seq_len, D)
    
    # Attention computation...
    return attention_output

Jittable Assertions

import functools

import chex
import jax
import jax.numpy as jnp

# Enable assertions inside jitted functions
@chex.chexify  # or @chex.with_jittable_assertions
@jax.jit
def safe_division(x, y):
    # These value assertions work inside jit!
    chex.assert_tree_all_finite(x)
    chex.assert_tree_all_finite(y)

    result = x / y
    # Division by zero yields inf/nan, which this check surfaces
    chex.assert_tree_all_finite(result)
    return result

# Use with async checking (chexify takes the function as its first
# positional argument, so pass options via functools.partial)
@functools.partial(chex.chexify, async_check=True)
@jax.jit
def training_step(params, batch):
    # Assertions are checked asynchronously 
    chex.assert_tree_all_finite(params)
    chex.assert_shape(batch['inputs'], (32, 784))
    
    # Training computation...
    loss = compute_loss(params, batch)
    grads = jax.grad(compute_loss)(params, batch)
    
    chex.assert_tree_all_finite(grads)
    chex.assert_tree_all_finite(loss)
    
    return grads, loss

# Block until all assertions complete
for epoch in range(num_epochs):
    for batch in dataloader:
        grads, loss = training_step(params, batch)
        params = update_params(params, grads)
    
    # Ensure all assertions from epoch have been checked
    chex.block_until_chexify_assertions_complete()
    print(f"Epoch {epoch} completed successfully")

# Configure error categories
@functools.partial(chex.chexify, errors=chex.ChexifyChecks.all)  # Check everything
@jax.jit
def comprehensive_checks(data):
    # Enables NaN, indexing, division, and user checks
    return jnp.mean(data)

@functools.partial(chex.chexify, errors=chex.ChexifyChecks.user | chex.ChexifyChecks.nan)
@jax.jit
def custom_checks(data):
    # Only user assertions and NaN checks
    return jnp.sum(data)

Deprecation Management

import functools

import chex

# Mark function as deprecated
@functools.partial(chex.warn_deprecated_function, replacement='new_function_name')
def old_function(x):
    """This function is deprecated."""
    return x + 1

# Create deprecated alias
def current_function(x, y):
    return x * y

# Create deprecated alias that warns users
old_function_name = chex.create_deprecated_function_alias(
    current_function, 
    'current_function',
    'old_function_name'
)

# Transition to keyword-only arguments
@functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
def transitioning_function(required_arg, optional_arg=None, another_arg=None):
    """Function transitioning to keyword-only arguments."""
    return required_arg + (optional_arg or 0) + (another_arg or 0)

# Usage that will warn:
# transitioning_function(1, 2, 3)  # Warning: only first arg should be positional

# Preferred usage:
# transitioning_function(1, optional_arg=2, another_arg=3)  # No warning

# Force keyword-only
@chex.warn_keyword_args_only_in_future
def keyword_only_function(*, arg1, arg2):
    """Function that should only accept keyword arguments."""
    return arg1 + arg2

# This will warn:
# keyword_only_function(1, 2)  # Warning about positional args

# This is correct:
# keyword_only_function(arg1=1, arg2=2)  # No warning

Advanced Integration Patterns

import functools

import chex
import jax
import jax.numpy as jnp

class AdvancedTrainer:
    """Training class with advanced Chex features."""
    
    def __init__(self, config):
        self.config = config
        
        # Set up dimensions
        self.dims = chex.Dimensions(
            B=config.batch_size,
            T=config.sequence_length,
            D=config.model_dim,
            C=config.num_classes
        )
        
        # Configure backend restrictions
        self.allowed_backends = config.allowed_backends
    
    def create_training_step(self):
        """Create jittable training step with assertions."""
        
        def training_step(state, batch):
            # Validate inputs
            chex.assert_tree_all_finite(state.params)
            chex.assert_shape(batch['inputs'], self.dims['BTD'])
            chex.assert_shape(batch['labels'], self.dims['BC'])
            
            # Forward pass
            def loss_fn(params):
                logits = self.model.apply(params, batch['inputs'])
                chex.assert_shape(logits, self.dims['BC'])
                # Cross-entropy via log-softmax (jax.nn has no fused
                # softmax_cross_entropy_with_logits helper)
                log_probs = jax.nn.log_softmax(logits)
                return -jnp.mean(jnp.sum(batch['labels'] * log_probs, axis=-1))
            
            loss, grads = jax.value_and_grad(loss_fn)(state.params)
            
            # Validate outputs
            chex.assert_tree_all_finite(loss)
            chex.assert_tree_all_finite(grads)
            
            # Update state
            new_state = self.optimizer.update(grads, state)
            return new_state, {'loss': loss}
        
        # chexify wraps the jitted function to enable value assertions
        return chex.chexify(jax.jit(training_step), async_check=True)
    
    def train(self, train_data):
        """Training loop with backend restriction."""
        
        # Restrict to allowed backends during training
        with chex.restrict_backends(allowed=self.allowed_backends):
            training_step = self.create_training_step()
            
            for epoch in range(self.config.num_epochs):
                for step, batch in enumerate(train_data):
                    # Validate batch dimensions dynamically
                    actual_batch_size = batch['inputs'].shape[0]
                    if (actual_batch_size,) != self.dims['B']:
                        # Update dimensions for the final, smaller batch
                        self.dims['B'] = (actual_batch_size,)
                    
                    state, metrics = training_step(self.state, batch)
                    self.state = state
                    
                    if step % 100 == 0:
                        # Ensure all async assertions have completed
                        chex.block_until_chexify_assertions_complete()
                        self.log_metrics(metrics, epoch, step)

# Integration with existing codebases
def modernize_legacy_function():
    """Example of gradually modernizing legacy code."""
    
    # Original function (deprecated)
    @functools.partial(chex.warn_deprecated_function, replacement='process_data_v2')
    def process_data_v1(data, normalize, scale):
        return data * scale if normalize else data
    
    # New function with better API
    @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
    def process_data_v2(data, *, normalize=False, scale=1.0):
        # Add shape validation
        chex.assert_rank(data, 2)
        chex.assert_scalar_positive(scale)
        
        if normalize:
            data = data / jnp.linalg.norm(data, axis=1, keepdims=True)
        
        return data * scale
    
    # Future version (keyword-only)
    def process_data_v3(*, data, normalize=False, scale=1.0):
        # Enhanced with jittable assertions; `normalize` is a Python bool,
        # so it is closed over rather than traced through jit
        @chex.chexify
        @jax.jit
        def _process(data, scale):
            chex.assert_rank(data, 2)
            chex.assert_tree_all_finite(data)

            if normalize:
                norms = jnp.linalg.norm(data, axis=1, keepdims=True)
                chex.assert_tree_all_finite(norms)
                data = data / norms

            result = data * scale
            chex.assert_tree_all_finite(result)
            return result

        return _process(data, scale)

Key Features

Fine-Grained Control

  • Precise backend restrictions for different computation phases
  • Flexible dimension management with arithmetic operations
  • Configurable assertion checking with multiple error categories

Production Ready

  • Async assertion checking for minimal performance impact
  • Deprecation management for smooth API transitions
  • Integration with existing JAX transformation pipeline

Developer Friendly

  • Clear error messages and warnings
  • Readable dimension specifications
  • Comprehensive debugging support

Best Practices

Use Backend Restrictions Strategically

# Good: Restrict during specific phases
with chex.restrict_backends(allowed=['cpu']):
    # Memory-intensive preprocessing
    pass

# Avoid: Overly broad restrictions
with chex.restrict_backends(forbidden=['gpu']):
    # Entire training loop - might be unnecessarily restrictive
    pass

Design Maintainable Dimension Systems

# Good: Centralized dimension management
dims = chex.Dimensions(B=32, T=100, D=512)

# Good: Clear dimension naming
dims = chex.Dimensions(
    batch_size=32,
    sequence_length=100, 
    embedding_dim=512
)

Plan Deprecation Carefully

# Good: Provide clear migration path
@functools.partial(chex.warn_deprecated_function, replacement='new_api_function')
def old_function():
    pass

# Good: Gradual transition
@functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
def transitioning_function(required, *, optional=None):
    pass

Install with Tessl CLI

npx tessl i tessl/pypi-chex
