tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation

docs/debugging.md

Debugging and Development Utilities

Development tools for patching JAX functions with fake implementations, enabling easier debugging, testing in CPU-only environments, and controlled development workflows.

Capabilities

Fake JAX Transformations

Functions to replace JAX transformations with simpler implementations for debugging.
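Under the hood, these helpers work by temporarily monkey-patching attributes on the `jax` module and restoring them on exit. A minimal stdlib-only sketch of that patching pattern (the `fakemod` stand-in and helper name are illustrative, not chex's actual internals):

```python
from contextlib import contextmanager
import types

# Stand-in "framework" whose jit wraps functions in a compile step.
fakemod = types.SimpleNamespace(
    jit=lambda fn, **kw: (lambda *args: ("compiled", fn(*args)))
)

@contextmanager
def fake_jit_sketch(module):
    """Temporarily replace module.jit with an identity decorator."""
    real_jit = module.jit
    module.jit = lambda fn, **kw: fn  # identity: skip "compilation"
    try:
        yield
    finally:
        module.jit = real_jit  # always restore, even on error

def square(x):
    return x * x

with fake_jit_sketch(fakemod):
    print(fakemod.jit(square)(3))   # 9 — runs as plain Python

print(fakemod.jit(square)(3))       # ('compiled', 9) — real wrapper restored
```

The try/finally restore is what makes the context manager safe to use around code that raises.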

def fake_jit(enable_patching=True):
    """
    Context manager that replaces jax.jit with an identity transform.
    
    While the context is active, jax.jit returns the original function
    without compilation, enabling:
    - Step-through debugging with standard Python debuggers
    - Faster iteration during development
    - Access to intermediate values and Python control flow
    
    Parameters:
    - enable_patching: If False, the context is a no-op and real jit is used
    
    Returns:
    - Context manager that disables jit compilation within its scope
    """

def fake_pmap(enable_patching=True, jit_result=False):
    """
    Context manager that replaces jax.pmap with jax.vmap.
    
    Enables testing of pmap code on machines without multiple devices
    by replacing parallel mapping with vectorized mapping.
    
    Parameters:
    - enable_patching: If False, the context is a no-op and real pmap is used
    - jit_result: If True, the vmapped function is additionally jitted
    
    Returns:
    - Context manager under which jax.pmap behaves like jax.vmap
    """

def fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True):
    """
    Context manager that replaces both jax.pmap and jax.jit.
    
    Combines fake_pmap and fake_jit behavior for comprehensive debugging
    of functions that use both transformations.
    
    Parameters:
    - enable_pmap_patching: If False, leaves jax.pmap untouched
    - enable_jit_patching: If False, leaves jax.jit untouched
    
    Returns:
    - Context manager under which pmap becomes vmap and jit is a no-op
    """

Device Configuration

Functions for controlling device behavior in testing environments.

def set_n_cpu_devices(n=None):
    """
    Force XLA to use n CPU threads as host devices.
    
    Enables testing of multi-device code (like pmap) on single-CPU machines
    by creating multiple virtual CPU devices.
    
    IMPORTANT: Must be called before any JAX operations or device queries.
    
    Parameters:
    - n: Number of CPU devices to create (uses FLAGS.chex_n_cpu_devices if None)
    
    Raises:
    - RuntimeError: If XLA backends are already initialized
    """

def get_n_cpu_devices_from_xla_flags():
    """
    Parse number of CPU devices from XLA environment flags.
    
    Returns:
    - Number of CPU devices configured in XLA_FLAGS (default: 1)
    """

Usage Examples

Basic Debugging Setup

import chex
import jax
import jax.numpy as jnp

# Original function with jit
@jax.jit
def compute_loss(params, data, labels):
    predictions = jnp.dot(data, params['weights']) + params['bias']
    return jnp.mean((predictions - labels) ** 2)

# Example inputs for the debug run
params = {'weights': jnp.ones((3, 1)), 'bias': jnp.zeros(())}
data = jnp.ones((4, 3))
labels = jnp.zeros((4, 1))

# For debugging, use the fake_jit context manager
with chex.fake_jit():
    # Now jax.jit calls become identity functions
    @jax.jit  # This becomes a no-op
    def compute_loss_debug(params, data, labels):
        predictions = jnp.dot(data, params['weights']) + params['bias']
        # Can now set breakpoints and inspect intermediate values
        print(f"Predictions shape: {predictions.shape}")
        loss = jnp.mean((predictions - labels) ** 2)
        print(f"Loss value: {loss}")
        return loss
    
    # Function executes without compilation
    result = compute_loss_debug(params, data, labels)

Testing Multi-Device Code

# Setup multiple CPU devices for testing
chex.set_n_cpu_devices(4)  # Must be called before any JAX operations

def parallel_computation(data):
    """Function designed to run on multiple devices."""
    return jnp.sum(data, axis=-1)

# Test with fake_pmap
with chex.fake_pmap():
    # pmap becomes vmap, works on single physical device
    parallel_fn = jax.pmap(parallel_computation)
    
    # Create data for 4 "devices" 
    batch_data = jnp.ones((4, 10, 5))  # (devices, batch, features)
    result = parallel_fn(batch_data)
    
    print(f"Result shape: {result.shape}")  # (4, 10)

Comprehensive Debugging Context

def debug_training_step(state, batch):
    """Training step with comprehensive debugging."""
    
    def loss_fn(params):
        logits = apply_model(params, batch['inputs'])
        # jax.nn has no softmax_cross_entropy_with_logits; compute the
        # cross entropy from log_softmax instead (optax.softmax_cross_entropy
        # is an equivalent option)
        log_probs = jax.nn.log_softmax(logits)
        return -jnp.mean(jnp.sum(batch['labels'] * log_probs, axis=-1))
    
    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    
    # Update parameters
    new_params = update_params(state.params, grads, state.optimizer)
    
    return state._replace(params=new_params), loss

# Use fake transformations for debugging
with chex.fake_pmap_and_jit():
    # Both pmap and jit are disabled
    @jax.pmap  # Becomes vmap
    @jax.jit   # Becomes identity
    def debug_step(state, batch):
        return debug_training_step(state, batch)
    
    # Can step through with debugger
    new_state, loss = debug_step(training_state, data_batch)

Conditional Debugging

import os
from contextlib import nullcontext

DEBUG_MODE = os.getenv('DEBUG_JAX', '0') == '1'

def create_training_function():
    if DEBUG_MODE:
        # Development mode: disable transformations
        context = chex.fake_pmap_and_jit()
    else:
        # Production mode: use real transformations
        context = nullcontext()
    
    with context:
        @jax.pmap
        @jax.jit
        def train_step(state, batch):
            # Training logic here
            return updated_state, metrics
    
    return train_step

# Usage
train_fn = create_training_function()
# Automatically uses fake or real transformations based on DEBUG_MODE

Device Setup for Testing

def setup_test_environment():
    """Setup consistent test environment across different machines."""
    
    try:
        # Try to set up multiple CPU devices for pmap testing
        chex.set_n_cpu_devices(8)
        print("Multi-device testing enabled")
        return True
    except RuntimeError as e:
        print(f"Single-device testing only: {e}")
        return False

def test_parallel_algorithm():
    multi_device = setup_test_environment()
    
    def algorithm(data):
        return jnp.mean(data ** 2)
    
    if multi_device:
        # Test with real pmap
        parallel_fn = jax.pmap(algorithm)
        test_data = jnp.ones((8, 100))  # 8 devices, 100 features each
    else:
        # Test with fake pmap (becomes vmap)
        with chex.fake_pmap():
            parallel_fn = jax.pmap(algorithm)
            test_data = jnp.ones((2, 100))  # Fewer "devices"
    
    result = parallel_fn(test_data)
    assert result.shape[0] == test_data.shape[0]

Advanced Debugging Patterns

from contextlib import nullcontext

class DebuggableModel:
    """Model class with built-in debugging support."""
    
    def __init__(self, debug=False):
        self.debug = debug
        self._debug_context = None
    
    def __enter__(self):
        # Create a fresh context on each entry; a context manager returned
        # by chex.fake_jit() should not be re-entered.
        self._debug_context = chex.fake_jit() if self.debug else nullcontext()
        self._debug_context.__enter__()
        return self
    
    def __exit__(self, *args):
        self._debug_context.__exit__(*args)
    
    def forward(self, params, inputs):
        # jit patching is already active inside the `with` block, so this
        # decorator is a no-op when debug=True.
        @jax.jit
        def _forward(params, inputs):
            # Model computation
            hidden = jnp.dot(inputs, params['W1']) + params['b1']
            if self.debug:
                print(f"Hidden layer stats: mean={jnp.mean(hidden):.3f}")
            
            hidden = jax.nn.relu(hidden)
            output = jnp.dot(hidden, params['W2']) + params['b2']
            
            if self.debug:
                print(f"Output layer stats: mean={jnp.mean(output):.3f}")
            
            return output
        
        return _forward(params, inputs)

# Usage
with DebuggableModel(debug=True) as model:
    predictions = model.forward(params, data)
    # Prints intermediate statistics when debug=True

Testing Framework Integration

import unittest

class TestWithDebugging(unittest.TestCase):
    
    def setUp(self):
        # Setup CPU devices for consistent testing
        try:
            chex.set_n_cpu_devices(4)
            self.multi_device = True
        except RuntimeError:
            self.multi_device = False
    
    def test_jitted_function(self):
        """Test function behavior with and without jit."""
        
        def compute_fn(x):
            return x ** 2 + 2 * x + 1
        
        x = jnp.array([1.0, 2.0, 3.0])
        
        # Test without jit (easier debugging)
        with chex.fake_jit():
            jitted_fn = jax.jit(compute_fn)
            result_fake = jitted_fn(x)
        
        # Test with real jit
        real_jitted_fn = jax.jit(compute_fn)
        result_real = real_jitted_fn(x)
        
        # Results should be identical
        chex.assert_trees_all_close(result_fake, result_real)
    
    def test_pmap_function(self):
        """Test pmap function with fake implementation."""
        
        def parallel_sum(x):
            return jnp.sum(x)
        
        if self.multi_device:
            # Test with real pmap
            pmapped_fn = jax.pmap(parallel_sum)
            test_data = jnp.ones((4, 10))
            result = pmapped_fn(test_data)
            expected_shape = (4,)
        else:
            # Test with fake pmap
            with chex.fake_pmap():
                pmapped_fn = jax.pmap(parallel_sum)
                test_data = jnp.ones((2, 10))
                result = pmapped_fn(test_data)
                expected_shape = (2,)
        
        self.assertEqual(result.shape, expected_shape)

Key Features

Non-Intrusive Debugging

  • Use context managers to temporarily disable transformations
  • Original code remains unchanged
  • Easy to toggle between debug and production modes

Multi-Device Testing

  • Test pmap code on single-device machines
  • Consistent behavior across different hardware configurations
  • Simplified development workflow
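The substitution that makes this possible — mapping a function over a leading "device" axis without any real devices — can be sketched in plain Python (a toy stand-in for the real vmap-based implementation):

```python
def fake_pmap_sketch(fn):
    """Toy pmap replacement: map fn over the leading 'device' axis in a loop."""
    def mapped(batched):
        # Each element of `batched` plays the role of one device's shard.
        return [fn(shard) for shard in batched]
    return mapped

sum_of_squares = fake_pmap_sketch(lambda row: sum(v * v for v in row))
print(sum_of_squares([[1, 2], [3, 4]]))  # [5, 25]
```

As with the real fake_pmap, the wrapped function keeps the same calling convention — only the execution strategy changes.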

Step-Through Debugging

  • Set breakpoints in jitted functions
  • Inspect intermediate values
  • Use standard Python debugging tools

Performance Development

  • Faster iteration during development
  • Skip compilation during debugging
  • Quick testing of algorithmic changes

Best Practices

Use Context Managers

# Good: Use context managers for temporary debugging
with chex.fake_jit():
    result = my_jitted_function(data)

# Avoid: Global patching that affects other code

Set Up Devices Early

# Good: call set_n_cpu_devices before any JAX computation initializes a backend
import chex
chex.set_n_cpu_devices(4)
# ... run JAX computations only after this point

# Avoid: calling set_n_cpu_devices after XLA backends are initialized
# (it raises RuntimeError)

Combine with Testing

# Good: Use debugging utilities in tests
class MyTest(chex.TestCase):
    def test_with_debugging(self):
        with chex.fake_jit():
            # Test logic here
            pass

Document Debug Modes

def my_function(data, debug=False):
    """Process data with optional debugging.
    
    Args:
        data: Input data
        debug: If True, disables jit for easier debugging
    """
    context = chex.fake_jit() if debug else nullcontext()
    with context:
        # Function implementation
        pass

Install with Tessl CLI

npx tessl i tessl/pypi-chex
