Comprehensive utilities library for JAX testing, debugging, and instrumentation.

Specialized utilities including backend restriction, dimension mapping, jittable assertions, and deprecation management for advanced JAX development scenarios.
Context manager for controlling JAX backend compilation and device usage.
def restrict_backends(*, allowed=None, forbidden=None):
"""
Context manager that prevents JAX compilation for specified backends.
Useful for ensuring code runs only on intended devices or catching
accidental compilation on restricted hardware.
Parameters:
- allowed: Sequence of allowed backend platform names (e.g., ['cpu', 'gpu'])
- forbidden: Sequence of forbidden backend platform names
Yields:
- Context where compilation for forbidden platforms raises RestrictedBackendError
Raises:
- ValueError: If neither allowed nor forbidden specified, or if conflicts exist
- RestrictedBackendError: If compilation attempted on restricted backend
"""
class RestrictedBackendError(RuntimeError):
"""
Exception raised when compilation attempted on restricted backend.
"""Utility class for managing named dimensions and shape specifications.
class Dimensions:
"""
Lightweight utility that maps strings to shape tuples.
Enables readable shape specifications using named dimensions
and supports dimension arithmetic and wildcard dimensions.
Examples:
>>> dims = chex.Dimensions(B=3, T=5, N=7)
>>> dims['NBT'] # (7, 3, 5)
>>> dims['(BT)N'] # (15, 7) - flattened dimensions
>>> dims['BT*'] # (3, 5, None) - wildcard dimension
"""
def __init__(self, **kwargs):
"""
Initialize dimensions with named size mappings.
Parameters:
- **kwargs: Dimension name to size mappings (e.g., B=32, T=100)
"""
def __getitem__(self, key):
"""
Get shape tuple for dimension string specification.
Parameters:
- key: String specifying dimensions (e.g., 'BTC', '(BT)C', 'BT*')
Returns:
- Tuple of integers and/or None for wildcard dimensions
"""
def __setitem__(self, key, value):
"""
Set dimension sizes from shape tuple.
Parameters:
- key: String specifying dimensions
- value: Shape tuple to assign to dimensions
"""
def size(self, key):
"""
Get total size (product) of specified dimensions.
Parameters:
- key: String specifying dimensions
Returns:
- Total number of elements in the specified shape
"""Advanced assertion system that works inside jitted functions using JAX checkify.
def chexify(
fn,
async_check=True,
errors=ChexifyChecks.user
):
"""
Enable Chex value assertions inside jitted functions.
Wraps function to enable runtime assertions that work with JAX transformations
by using JAX's checkify system for delayed error checking.
Parameters:
- fn: Function to wrap with jittable assertions
- async_check: Whether to check errors asynchronously
- errors: Set of error categories to check (from ChexifyChecks)
Returns:
- Wrapped function that supports Chex assertions inside jit
"""
def with_jittable_assertions(fn):
"""
Decorator for enabling jittable assertions in a function.
Equivalent to chexify(fn) but as a decorator.
Parameters:
- fn: Function to decorate
Returns:
- Function with jittable assertions enabled
"""
def block_until_chexify_assertions_complete():
"""
Wait for all asynchronous assertion checks to complete.
Should be called after computations that use chexify to ensure
all assertion errors are properly surfaced.
"""
class ChexifyChecks:
"""
Collection of checkify error categories for jittable assertions.
Attributes:
- user: User-defined checks (Chex assertions)
- nan: NaN detection checks
- index: Array indexing checks
- div: Division by zero checks
- float: Floating point error checks
- automatic: Automatically enabled checks
- all: All available checks
"""Utilities for managing deprecated functions and warning users about API changes.
def warn_deprecated_function(fun, replacement=None):
"""
Decorator to mark a function as deprecated.
Emits DeprecationWarning when the decorated function is called.
Parameters:
- fun: Function to mark as deprecated
- replacement: Optional name of replacement function
Returns:
- Wrapped function that emits deprecation warning
"""
def create_deprecated_function_alias(fun, new_name, deprecated_alias):
"""
Create a deprecated alias for a function.
Creates a new function that emits deprecation warning and delegates
to the original function.
Parameters:
- fun: Original function
- new_name: Current name of the function
- deprecated_alias: Deprecated alias name
Returns:
- Deprecated alias function
"""
def warn_only_n_pos_args_in_future(fun, n):
"""
Warn if more than n positional arguments are passed.
Helps transition functions to keyword-only arguments by warning
when too many positional arguments are used.
Parameters:
- fun: Function to wrap
- n: Maximum number of allowed positional arguments
Returns:
- Wrapped function that warns about excess positional arguments
"""
def warn_keyword_args_only_in_future(fun):
"""
Warn if any positional arguments are passed (keyword-only transition).
Equivalent to warn_only_n_pos_args_in_future(fun, 0).
Parameters:
- fun: Function to wrap
Returns:
- Wrapped function that warns about positional arguments
"""import chex
import jax
import jax.numpy as jnp
# Ensure computation only runs on CPU
with chex.restrict_backends(allowed=['cpu']):
@jax.jit
def cpu_only_computation(x):
return x ** 2
result = cpu_only_computation(jnp.array([1, 2, 3]))
# Works fine - compiles for CPU
# Prevent accidental GPU usage
with chex.restrict_backends(forbidden=['gpu', 'tpu']):
try:
@jax.jit(device=jax.devices('gpu')[0]) # Attempt GPU compilation
def gpu_computation(x):
return x + 1
gpu_computation(jnp.array([1]))
except chex.RestrictedBackendError:
print("GPU compilation blocked as expected")
# Restrict during specific phases
def training_phase(model_fn, data):
# Ensure training only uses CPUs (e.g., for memory reasons)
with chex.restrict_backends(allowed=['cpu']):
return model_fn(data)
def inference_phase(model_fn, data):
# Allow inference on any available device
return model_fn(data)

import chex
import jax.numpy as jnp
# Create dimension mapping for transformer model
dims = chex.Dimensions(
B=32, # Batch size
T=512, # Sequence length
D=768, # Model dimension
H=12, # Number of heads
V=50000 # Vocabulary size
)
# Use dimensions for shape assertions
def transformer_layer(
inputs, # Shape: (B, T, D)
weights_qkv, # Shape: (D, 3*D)
weights_out # Shape: (D, D)
):
# Validate input shapes using dimension names
chex.assert_shape(inputs, dims['BTD'])
chex.assert_shape(weights_qkv, (dims.D, 3 * dims.D))
chex.assert_shape(weights_out, dims['DD'])
# Compute attention
batch_size, seq_len, model_dim = inputs.shape
# Query, Key, Value projections
qkv = jnp.dot(inputs, weights_qkv) # (B, T, 3*D)
qkv = qkv.reshape(batch_size, seq_len, 3, dims.H, dims.D // dims.H)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# Multi-head attention computation...
# Output shape should be (B, T, D)
output = jnp.dot(attention_output, weights_out)
chex.assert_shape(output, dims['BTD'])
return output
# Dynamic dimension updates
def process_variable_batch(data):
# Update batch dimension based on actual data
dims['B'] = data.shape[0]
# Use updated dimensions
chex.assert_shape(data, dims['BTD'])
return data
# Flattened dimensions for linear layers
def create_classifier_weights():
# Flatten sequence and model dimensions
input_size = dims.size('TD') # T * D = 512 * 768
output_size = dims.V # Vocabulary size
return jnp.ones((input_size, output_size))
# Wildcard dimensions
def flexible_attention(queries, keys, values):
# Allow any sequence length but fixed model dimension
chex.assert_shape(queries, dims['B*D']) # (B, any_seq_len, D)
chex.assert_shape(keys, dims['B*D']) # (B, any_seq_len, D)
chex.assert_shape(values, dims['B*D']) # (B, any_seq_len, D)
# Attention computation...
return attention_output

import chex
import jax
import jax.numpy as jnp
# Enable assertions inside jitted functions
@chex.chexify # or @chex.with_jittable_assertions
@jax.jit
def safe_division(x, y):
# These assertions work inside jit!
chex.assert_tree_all_finite(x)
chex.assert_tree_all_finite(y)
chex.assert_scalar_positive(y) # Ensure no division by zero
result = x / y
chex.assert_tree_all_finite(result)
return result
# Use with async checking
@chex.chexify(async_check=True)
@jax.jit
def training_step(params, batch):
# Assertions are checked asynchronously
chex.assert_tree_all_finite(params)
chex.assert_shape(batch['inputs'], (32, 784))
# Training computation...
loss = compute_loss(params, batch)
grads = jax.grad(compute_loss)(params, batch)
chex.assert_tree_all_finite(grads)
chex.assert_scalar_positive(loss)
return grads, loss
# Block until all assertions complete
for epoch in range(num_epochs):
for batch in dataloader:
grads, loss = training_step(params, batch)
params = update_params(params, grads)
# Ensure all assertions from epoch have been checked
chex.block_until_chexify_assertions_complete()
print(f"Epoch {epoch} completed successfully")
# Configure error categories
@chex.chexify(errors=chex.ChexifyChecks.all) # Check everything
@jax.jit
def comprehensive_checks(data):
# Enables NaN, indexing, division, and user checks
return jnp.mean(data)
@chex.chexify(errors=chex.ChexifyChecks.user | chex.ChexifyChecks.nan)
@jax.jit
def custom_checks(data):
# Only user assertions and NaN checks
return jnp.sum(data)

import chex
# Mark function as deprecated
@chex.warn_deprecated_function(replacement='new_function_name')
def old_function(x):
"""This function is deprecated."""
return x + 1
# Create deprecated alias
def current_function(x, y):
return x * y
# Create deprecated alias that warns users
old_function_name = chex.create_deprecated_function_alias(
current_function,
'current_function',
'old_function_name'
)
# Transition to keyword-only arguments
@chex.warn_only_n_pos_args_in_future(n=1)
def transitioning_function(required_arg, optional_arg=None, another_arg=None):
"""Function transitioning to keyword-only arguments."""
return required_arg + (optional_arg or 0) + (another_arg or 0)
# Usage that will warn:
# transitioning_function(1, 2, 3) # Warning: only first arg should be positional
# Preferred usage:
# transitioning_function(1, optional_arg=2, another_arg=3) # No warning
# Force keyword-only
@chex.warn_keyword_args_only_in_future
def keyword_only_function(*, arg1, arg2):
"""Function that should only accept keyword arguments."""
return arg1 + arg2
# This will warn:
# keyword_only_function(1, 2) # Warning about positional args
# This is correct:
# keyword_only_function(arg1=1, arg2=2) # No warning

import chex
import jax
import jax.numpy as jnp
class AdvancedTrainer:
"""Training class with advanced Chex features."""
def __init__(self, config):
self.config = config
# Set up dimensions
self.dims = chex.Dimensions(
B=config.batch_size,
T=config.sequence_length,
D=config.model_dim,
C=config.num_classes
)
# Configure backend restrictions
self.allowed_backends = config.allowed_backends
@chex.chexify(async_check=True)
def create_training_step(self):
"""Create jittable training step with assertions."""
def training_step(state, batch):
# Validate inputs
chex.assert_tree_all_finite(state.params)
chex.assert_shape(batch['inputs'], self.dims['BTD'])
chex.assert_shape(batch['labels'], self.dims['BC'])
# Forward pass
def loss_fn(params):
logits = self.model.apply(params, batch['inputs'])
chex.assert_shape(logits, self.dims['BC'])
return jnp.mean(jax.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=batch['labels']
))
loss, grads = jax.value_and_grad(loss_fn)(state.params)
# Validate outputs
chex.assert_scalar_positive(loss)
chex.assert_tree_all_finite(grads)
# Update state
new_state = self.optimizer.update(grads, state)
return new_state, {'loss': loss}
return jax.jit(training_step)
def train(self, train_data):
"""Training loop with backend restriction."""
# Restrict to allowed backends during training
with chex.restrict_backends(allowed=self.allowed_backends):
training_step = self.create_training_step()
for epoch in range(self.config.num_epochs):
for step, batch in enumerate(train_data):
# Validate batch dimensions dynamically
actual_batch_size = batch['inputs'].shape[0]
if actual_batch_size != self.dims.B:
# Update dimensions for final batch
self.dims['B'] = actual_batch_size
state, metrics = training_step(self.state, batch)
self.state = state
if step % 100 == 0:
# Ensure all async assertions have completed
chex.block_until_chexify_assertions_complete()
self.log_metrics(metrics, epoch, step)
# Integration with existing codebases
def modernize_legacy_function():
"""Example of gradually modernizing legacy code."""
# Original function (deprecated)
@chex.warn_deprecated_function(replacement='process_data_v2')
def process_data_v1(data, normalize, scale):
return data * scale if normalize else data
# New function with better API
@chex.warn_only_n_pos_args_in_future(n=1)
def process_data_v2(data, *, normalize=False, scale=1.0):
# Add shape validation
chex.assert_rank(data, 2)
chex.assert_scalar_positive(scale)
if normalize:
data = data / jnp.linalg.norm(data, axis=1, keepdims=True)
return data * scale
# Future version (keyword-only)
def process_data_v3(*, data, normalize=False, scale=1.0):
# Enhanced with jittable assertions
@chex.chexify
@jax.jit
def _process(data, normalize, scale):
chex.assert_rank(data, 2)
chex.assert_scalar_positive(scale)
chex.assert_tree_all_finite(data)
if normalize:
norms = jnp.linalg.norm(data, axis=1, keepdims=True)
chex.assert_tree_all_finite(norms)
data = data / norms
result = data * scale
chex.assert_tree_all_finite(result)
return result
return _process(data, normalize, scale)

# Good: Restrict during specific phases
with chex.restrict_backends(allowed=['cpu']):
# Memory-intensive preprocessing
pass
# Avoid: Overly broad restrictions
with chex.restrict_backends(forbidden=['gpu']):
# Entire training loop - might be unnecessarily restrictive
pass

# Good: Centralized dimension management
dims = chex.Dimensions(B=32, T=100, D=512)
# Good: Clear dimension naming
dims = chex.Dimensions(
batch_size=32,
sequence_length=100,
embedding_dim=512
)

# Good: Provide clear migration path
@chex.warn_deprecated_function(replacement='new_api_function')
def old_function():
pass
# Good: Gradual transition
@chex.warn_only_n_pos_args_in_future(n=1)
def transitioning_function(required, *, optional=None):
pass

Install with Tessl CLI
npx tessl i tessl/pypi-chexevals
scenario-1
scenario-2
scenario-3
scenario-4
scenario-5
scenario-6
scenario-7
scenario-8
scenario-9
scenario-10