Comprehensive utilities library for JAX testing, debugging, and instrumentation.

Development tools for patching JAX functions with fake implementations, enabling easier debugging, testing in CPU-only environments, and controlled development workflows.
Functions to replace JAX transformations with simpler implementations for debugging.
def fake_jit(fn, **kwargs):
    """
    Replace jax.jit with identity function for debugging.

    Returns the original function without compilation, enabling:
    - Step-through debugging with standard Python debuggers
    - Faster iteration during development
    - Access to intermediate values and Python control flow

    Parameters:
    - fn: Function that would normally be jitted
    - **kwargs: Ignored (for compatibility with jax.jit signature)

    Returns:
    - Original function without jit compilation
    """
    # All jit options are deliberately discarded: the whole point of the
    # fake is to execute `fn` as plain, uncompiled Python.
    del kwargs
    return fn
def fake_pmap(fn, axis_name=None, **kwargs):
    """
    Replace jax.pmap with vmap for debugging on single device.

    Enables testing of pmap code on machines without multiple devices
    by replacing parallel mapping with vectorized mapping.

    NOTE(review): some usage examples in this file invoke fake_pmap as a
    context manager; this implementation follows the documented return
    contract (function decorator) — confirm the intended API shape.

    Parameters:
    - fn: Function that would normally be pmapped
    - axis_name: Axis name (ignored in fake implementation)
    - **kwargs: Additional pmap arguments (most ignored)

    Returns:
    - Function wrapped with vmap instead of pmap
    """
    # axis_name and the remaining pmap options have no meaning for vmap;
    # drop them so the wrapped function maps purely over the leading axis.
    del axis_name, kwargs
    return jax.vmap(fn)
def fake_pmap_and_jit(fn, **kwargs):
    """
    Replace both jax.pmap and jax.jit with simpler implementations.

    Combines fake_pmap and fake_jit behavior for comprehensive debugging
    of functions that use both transformations.

    Parameters:
    - fn: Function to wrap
    - **kwargs: Ignored transformation arguments

    Returns:
    - Function with both pmap and jit removed
    """
    del kwargs
    # jit becomes the identity, so the only substitution left to perform
    # is pmap -> vmap over the leading (device) axis.
    return jax.vmap(fn)

# Functions for controlling device behavior in testing environments.
def set_n_cpu_devices(n=None):
    """
    Force XLA to use n CPU threads as host devices.

    Enables testing of multi-device code (like pmap) on single-CPU machines
    by creating multiple virtual CPU devices.

    IMPORTANT: Must be called before any JAX operations or device queries;
    XLA reads XLA_FLAGS only once, when its backends initialize.
    NOTE(review): the documented contract raises RuntimeError when XLA
    backends are already initialized; detecting that requires querying JAX
    internals, which would itself initialize the backends — confirm how the
    real implementation performs this check.

    Parameters:
    - n: Number of CPU devices to create (uses FLAGS.chex_n_cpu_devices if None)

    Raises:
    - RuntimeError: If XLA backends are already initialized
    """
    import os  # local import keeps this stub file's top-level imports unchanged

    if n is None:
        # NOTE(review): the contract reads FLAGS.chex_n_cpu_devices here,
        # but that flag object is not visible in this file — fall back to a
        # single device. TODO confirm against the real flag default.
        n = 1
    flag = f'--xla_force_host_platform_device_count={n}'
    existing = os.environ.get('XLA_FLAGS', '')
    # Drop any previous device-count flag so repeated calls don't conflict.
    kept = [f for f in existing.split()
            if not f.startswith('--xla_force_host_platform_device_count')]
    os.environ['XLA_FLAGS'] = ' '.join(kept + [flag]).strip()
def get_n_cpu_devices_from_xla_flags():
    """
    Parse number of CPU devices from XLA environment flags.

    Returns:
    - Number of CPU devices configured in XLA_FLAGS (default: 1)
    """
    import os
    import re

    # XLA exposes the virtual host-device count via this flag; absence of
    # the flag (or of XLA_FLAGS entirely) means a single CPU device.
    match = re.search(
        r'--xla_force_host_platform_device_count=(\d+)',
        os.environ.get('XLA_FLAGS', ''))
    return int(match.group(1)) if match else 1
import os
import unittest
from contextlib import nullcontext

import chex
import jax
import jax.numpy as jnp
# Original function with jit
@jax.jit
def compute_loss(params, data, labels):
    """Mean-squared error of a linear model dot(data, W) + b against labels."""
    preds = jnp.dot(data, params['weights']) + params['bias']
    errors = preds - labels
    return jnp.mean(errors ** 2)
# For debugging, use the fake_jit context manager.
with chex.fake_jit():
    # Inside this context, jax.jit calls become identity functions.
    @jax.jit  # This becomes a no-op
    def compute_loss_debug(params, data, labels):
        predictions = jnp.dot(data, params['weights']) + params['bias']
        # Can now set breakpoints and inspect intermediate values.
        print(f"Predictions shape: {predictions.shape}")
        loss = jnp.mean((predictions - labels) ** 2)
        print(f"Loss value: {loss}")
        return loss

    # Function executes without compilation.
    result = compute_loss_debug(params, data, labels)

# Setup multiple CPU devices for testing
chex.set_n_cpu_devices(4)  # Must be called before any JAX operations

def parallel_computation(data):
    """Function designed to run on multiple devices."""
    return jnp.sum(data, axis=-1)

# Test with fake_pmap
with chex.fake_pmap():
    # pmap becomes vmap, works on single physical device
    parallel_fn = jax.pmap(parallel_computation)
    # Create data for 4 "devices"
    batch_data = jnp.ones((4, 10, 5))  # (devices, batch, features)
    result = parallel_fn(batch_data)
    print(f"Result shape: {result.shape}")  # (4, 10)

def debug_training_step(state, batch):
    """Training step with comprehensive debugging."""
    def loss_fn(params):
        # NOTE(review): apply_model and update_params are not defined in
        # this file — presumably supplied by the surrounding project.
        logits = apply_model(params, batch['inputs'])
        return jnp.mean(jax.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=batch['labels']
        ))

    # Compute the loss together with its gradients in one pass.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Apply the parameter update.
    new_params = update_params(state.params, grads, state.optimizer)
    return state._replace(params=new_params), loss
# Use fake transformations for debugging
with chex.fake_pmap_and_jit():
    # Both pmap and jit are disabled inside this context.
    @jax.pmap  # Becomes vmap
    @jax.jit   # Becomes identity
    def debug_step(state, batch):
        return debug_training_step(state, batch)

    # Can step through with debugger
    new_state, loss = debug_step(training_state, data_batch)

import os
DEBUG_MODE = os.getenv('DEBUG_JAX', '0') == '1'

def create_training_function():
    """Build a train step; transforms are faked when DEBUG_JAX=1."""
    # NOTE(review): nullcontext must come from contextlib; it is not
    # defined anywhere in this snippet itself.
    if DEBUG_MODE:
        context = chex.fake_pmap_and_jit()  # development: transforms disabled
    else:
        context = nullcontext()  # production: real transformations
    with context:
        @jax.pmap
        @jax.jit
        def train_step(state, batch):
            # Training logic here
            return updated_state, metrics
    return train_step

# Usage
train_fn = create_training_function()
# Automatically uses fake or real transformations based on DEBUG_MODE

def setup_test_environment():
    """Setup consistent test environment across different machines."""
    try:
        # Try to set up multiple CPU devices for pmap testing
        chex.set_n_cpu_devices(8)
        print("Multi-device testing enabled")
        return True
    except RuntimeError as e:
        print(f"Single-device testing only: {e}")
        return False
def test_parallel_algorithm():
    """Exercise a simple algorithm under real or fake pmap."""
    multi_device = setup_test_environment()

    def algorithm(data):
        return jnp.mean(data ** 2)

    if multi_device:
        # Test with real pmap
        parallel_fn = jax.pmap(algorithm)
        test_data = jnp.ones((8, 100))  # 8 devices, 100 features each
    else:
        # Test with fake pmap (becomes vmap)
        with chex.fake_pmap():
            parallel_fn = jax.pmap(algorithm)
            test_data = jnp.ones((2, 100))  # Fewer "devices"

    result = parallel_fn(test_data)
    assert result.shape[0] == test_data.shape[0]

class DebuggableModel:
    """Model class with built-in debugging support."""

    def __init__(self, debug=False):
        self.debug = debug
        # When debugging, jit is patched out so prints and breakpoints work.
        self._debug_context = chex.fake_jit() if debug else nullcontext()

    def __enter__(self):
        self._debug_context.__enter__()
        return self

    def __exit__(self, *args):
        self._debug_context.__exit__(*args)

    def forward(self, params, inputs):
        """Two-layer MLP forward pass; prints layer stats when debug=True."""
        with self._debug_context:
            @jax.jit
            def _forward(params, inputs):
                # Model computation
                hidden = jnp.dot(inputs, params['W1']) + params['b1']
                if self.debug:
                    print(f"Hidden layer stats: mean={jnp.mean(hidden):.3f}")
                hidden = jax.nn.relu(hidden)
                output = jnp.dot(hidden, params['W2']) + params['b2']
                if self.debug:
                    print(f"Output layer stats: mean={jnp.mean(output):.3f}")
                return output
            return _forward(params, inputs)

# Usage
with DebuggableModel(debug=True) as model:
    predictions = model.forward(params, data)
    # Prints intermediate statistics when debug=True

import unittest
class TestWithDebugging(unittest.TestCase):
    """Unit tests that work with or without multiple (virtual) devices."""

    def setUp(self):
        # Setup CPU devices for consistent testing
        try:
            chex.set_n_cpu_devices(4)
            self.multi_device = True
        except RuntimeError:
            self.multi_device = False

    def test_jitted_function(self):
        """Test function behavior with and without jit."""
        def compute_fn(x):
            return x ** 2 + 2 * x + 1

        x = jnp.array([1.0, 2.0, 3.0])

        # Without jit (easier debugging)
        with chex.fake_jit():
            jitted_fn = jax.jit(compute_fn)
            result_fake = jitted_fn(x)

        # With real jit
        real_jitted_fn = jax.jit(compute_fn)
        result_real = real_jitted_fn(x)

        # Results should be identical either way.
        chex.assert_trees_all_close(result_fake, result_real)

    def test_pmap_function(self):
        """Test pmap function with fake implementation."""
        def parallel_sum(x):
            return jnp.sum(x)

        if self.multi_device:
            # Real pmap across the virtual devices.
            pmapped_fn = jax.pmap(parallel_sum)
            test_data = jnp.ones((4, 10))
            result = pmapped_fn(test_data)
            expected_shape = (4,)
        else:
            # Fake pmap (vmap) on a single device.
            with chex.fake_pmap():
                pmapped_fn = jax.pmap(parallel_sum)
                test_data = jnp.ones((2, 10))
                result = pmapped_fn(test_data)
            expected_shape = (2,)

        self.assertEqual(result.shape, expected_shape)

# Good: Use context managers for temporary debugging
with chex.fake_jit():
    result = my_jitted_function(data)
# Avoid: Global patching that affects other code

# Good: Set up devices before any JAX operations
chex.set_n_cpu_devices(4)
import jax  # JAX operations after device setup
# Avoid: Setting devices after JAX initialization

# Good: Use debugging utilities in tests
class MyTest(chex.TestCase):
    def test_with_debugging(self):
        with chex.fake_jit():
            # Test logic here
            pass

def my_function(data, debug=False):
    """Process data with optional debugging.

    Args:
        data: Input data
        debug: If True, disables jit for easier debugging
    """
    # NOTE(review): nullcontext must come from contextlib.
    context = chex.fake_jit() if debug else nullcontext()
    with context:
        # Function implementation
        pass

# Install with Tessl CLI
npx tessl i tessl/pypi-chexevals
scenario-1
scenario-2
scenario-3
scenario-4
scenario-5
scenario-6
scenario-7
scenario-8
scenario-9
scenario-10