Comprehensive utilities library for JAX testing, debugging, and instrumentation.

Testing framework that enables running the same test code across multiple JAX
execution variants (jitted vs. non-jitted, different devices, with pmap, etc.)
for comprehensive validation of JAX code behavior.

Base test case class providing variant testing infrastructure.
class TestCase(parameterized.TestCase):
    """Base class for Chex tests that use variants.

    Provides infrastructure for running tests across multiple JAX execution
    modes. Subclasses from absl.testing.parameterized.TestCase to support
    generator unrolling.
    """

    def variant(self, *args, **kwargs):
        """Access the current test variant function.

        This method is dynamically replaced by the @variants decorator
        with the appropriate transformation (jit, identity, etc.).

        Raises:
            RuntimeError: If called without the @variants decorator.
        """
        # The @variants decorator replaces this attribute per variant at test
        # time; reaching this body means the test method was not decorated, so
        # raise the error the contract above documents.
        raise RuntimeError(
            'self.variant must only be used inside tests decorated with '
            '@variants or @all_variants.')


# Decorators for running tests across multiple execution modes.
def variants(*variant_types):
    """Decorator to run a test across the specified variants.

    Args:
        *variant_types: ChexVariantType values specifying which variants to
            test.

    Returns:
        Generator yielding one test per requested variant.

    Example:
        @variants(ChexVariantType.WITH_JIT, ChexVariantType.WITHOUT_JIT)
        def test_function(self):
            fn = self.variant(my_function)
            # Test implementation
    """
def all_variants(*variant_types):
    """Decorator to run a test across all available variants.

    Args:
        *variant_types: Optional variant types to include (defaults to all
            available variants).

    Returns:
        Generator yielding one test per variant.
    """


# Enumeration of available test variant types.
class ChexVariantType(Enum):
    """Enumeration of available Chex test variants.

    Use ``self.variant.type`` to get the type of the current test variant.
    """

    WITH_JIT = 1        # Function wrapped with jax.jit.
    WITHOUT_JIT = 2     # Function executed directly (identity).
    WITH_DEVICE = 3     # Function executed on a specific device.
    WITHOUT_DEVICE = 4  # Function executed on the default device.
    WITH_PMAP = 5       # Function wrapped with jax.pmap.


# Utilities for generating test parameter combinations.
import itertools


def params_product(*params_lists, named=False):
    """Generate the cartesian product of parameter lists for parameterized tests.

    Args:
        *params_lists: Sequences of parameter values.
        named: Whether to generate test names for
            parameterized.named_parameters.

    Returns:
        Sequence of parameter combinations. When ``named`` is True, each
        combination is prefixed with a generated test-case name.

    Example:
        # Generate all combinations of batch sizes and learning rates
        params = params_product([32, 64], [0.01, 0.001])
        # [(32, 0.01), (32, 0.001), (64, 0.01), (64, 0.001)]
    """
    combinations = list(itertools.product(*params_lists))
    if not named:
        return combinations
    # NOTE(review): the generated name joins the stringified values with '_';
    # this is assumed from the documented use with
    # parameterized.named_parameters — confirm against chex.params_product.
    return [('_'.join(str(value) for value in combo), *combo)
            for combo in combinations]


import chex
import jax
import jax.numpy as jnp
class MyTest(chex.TestCase):
@chex.variants(chex.ChexVariantType.WITH_JIT, chex.ChexVariantType.WITHOUT_JIT)
def test_addition_function(self):
def add_one(x):
return x + 1
# Get the variant-appropriate version of the function
fn = self.variant(add_one)
# Test the function
result = fn(jnp.array([1, 2, 3]))
expected = jnp.array([2, 3, 4])
chex.assert_equal(result, expected)
# Access variant type if needed
if self.variant.type == chex.ChexVariantType.WITH_JIT:
# This test is running with jit
passclass ComprehensiveTest(chex.TestCase):
@chex.all_variants
def test_matrix_multiply(self):
def matmul(a, b):
return jnp.dot(a, b)
fn = self.variant(matmul)
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])
result = fn(a, b)
expected = jnp.array([[19, 22], [43, 50]])
chex.assert_equal(result, expected)from absl.testing import parameterized
class ParameterizedVariantTest(chex.TestCase):
@chex.variants(chex.ChexVariantType.WITH_JIT, chex.ChexVariantType.WITHOUT_JIT)
@parameterized.parameters(
{'batch_size': 32, 'input_dim': 784},
{'batch_size': 64, 'input_dim': 1024},
)
def test_neural_network_layer(self, batch_size, input_dim):
def linear_layer(x, weights, bias):
return jnp.dot(x, weights) + bias
fn = self.variant(linear_layer)
# Create test data
x = jnp.ones((batch_size, input_dim))
weights = jnp.ones((input_dim, 10))
bias = jnp.zeros(10)
result = fn(x, weights, bias)
# Verify output shape
chex.assert_shape(result, (batch_size, 10))
# Verify computation
expected = jnp.full((batch_size, 10), input_dim)
chex.assert_equal(result, expected)class ProductTest(chex.TestCase):
# Generate all combinations of optimizers and learning rates
@parameterized.parameters(
*chex.params_product(
['sgd', 'adam', 'rmsprop'],
[0.1, 0.01, 0.001],
named=True
)
)
@chex.variants(chex.ChexVariantType.WITH_JIT, chex.ChexVariantType.WITHOUT_JIT)
def test_optimizer_update(self, optimizer_name, learning_rate):
def update_step(params, grads, lr):
return params - lr * grads
fn = self.variant(update_step)
params = jnp.array([1.0, 2.0, 3.0])
grads = jnp.array([0.1, 0.2, 0.3])
updated_params = fn(params, grads, learning_rate)
expected = params - learning_rate * grads
chex.assert_trees_all_close(updated_params, expected)class DeviceTest(chex.TestCase):
@chex.variants(
chex.ChexVariantType.WITH_DEVICE,
chex.ChexVariantType.WITHOUT_DEVICE
)
def test_device_placement(self):
def compute_sum(x):
return jnp.sum(x)
fn = self.variant(compute_sum)
x = jnp.array([1, 2, 3, 4, 5])
result = fn(x)
chex.assert_equal(result, 15)
# Can check device placement if needed
if hasattr(result, 'device'):
# Verify device placement based on variant type
passclass PmapTest(chex.TestCase):
@chex.variants(chex.ChexVariantType.WITH_PMAP)
def test_parallel_computation(self):
def parallel_square(x):
return x ** 2
fn = self.variant(parallel_square)
# Create data for multiple devices
n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 4).reshape(n_devices, 4)
result = fn(x)
expected = x ** 2
chex.assert_equal(result, expected)class AdvancedVariantTest(chex.TestCase):
def setUp(self):
super().setUp()
# Setup that runs before each variant
self.tolerance = 1e-6
@chex.all_variants
def test_gradient_computation(self):
def loss_fn(params, data):
return jnp.sum((params['w'] @ data - params['b']) ** 2)
# Get variant-appropriate version
loss_fn = self.variant(loss_fn)
# Create test data
params = {'w': jnp.array([[1.0, 2.0]]), 'b': jnp.array([0.5])}
data = jnp.array([[1.0], [2.0]])
# Compute gradients
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, data)
# Verify gradient structure matches params
chex.assert_trees_all_equal_structs(grads, params)
# Verify gradients are finite
chex.assert_tree_all_finite(grads)
def test_variant_type_specific_behavior(self):
"""Test that demonstrates variant-specific testing logic."""
@chex.variants(chex.ChexVariantType.WITH_JIT, chex.ChexVariantType.WITHOUT_JIT)
def _test_impl(self):
def expensive_computation(x):
# Some computation that might behave differently jitted/non-jitted
return jnp.sum(jnp.sin(x) * jnp.cos(x))
fn = self.variant(expensive_computation)
x = jnp.linspace(0, 2 * jnp.pi, 1000)
result = fn(x)
# Different expectations based on variant type
if self.variant.type == chex.ChexVariantType.WITH_JIT:
# Jitted version might have slightly different numerical behavior
chex.assert_scalar(result)
else:
# Non-jitted version
chex.assert_scalar(result)
# Execute the test
_test_impl(self)@chex.variants(chex.ChexVariantType.WITH_JIT, chex.ChexVariantType.WITHOUT_JIT)
def test_neural_network_forward_pass_consistency(self):
# Clear test purpose
pass# Focus variant testing on functions that will be jitted in practice
@chex.all_variants
def test_training_step(self):
# This will be jitted in real usage
pass@chex.all_variants
def test_with_comprehensive_checks(self):
fn = self.variant(my_function)
result = fn(input_data)
# Use Chex assertions for thorough validation
chex.assert_shape(result, expected_shape)
chex.assert_tree_all_finite(result)Install with Tessl CLI
npx tessl i tessl/pypi-chex

evals
scenario-1
scenario-2
scenario-3
scenario-4
scenario-5
scenario-6
scenario-7
scenario-8
scenario-9
scenario-10