Comprehensive utilities library for JAX testing, debugging, and instrumentation
npx @tessl/cli install tessl/pypi-chex@0.1.0

Chex is a comprehensive utilities library designed specifically for JAX development that provides essential tools for testing, debugging, and instrumentation. The library offers robust assertion functions for validating tensor shapes, dimensions, and properties in JAX computations; debugging utilities that can transform pmaps into vmaps for easier development workflows; and comprehensive testing infrastructure that runs the same test code across multiple variants (jitted vs. non-jitted, different devices).
pip install chex

import chex

Individual modules can be imported as needed:
# Core assertions and utilities
from chex import assert_shape, assert_equal, dataclass
from chex import fake_jit, fake_pmap, variants, TestCase
import chex
import jax  # needed below for jax.jit in the fake_jit example
import jax.numpy as jnp

# Shape and dimension assertions
x = jnp.array([[1, 2, 3], [4, 5, 6]])
chex.assert_shape(x, (2, 3))  # Verify expected shape
chex.assert_rank(x, 2)  # Verify number of dimensions


# Testing with variants - run same test jitted and non-jitted
class MyTest(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    def test_my_function(self):
        def my_fn(x):
            return x + 1

        fn = self.variant(my_fn)  # Automatically jitted or not
        result = fn(jnp.array([1, 2, 3]))
        # chex.assert_equal is a plain `==` check, which is ambiguous for
        # multi-element arrays; compare array contents leaf-wise instead.
        chex.assert_trees_all_equal(result, jnp.array([2, 3, 4]))


# JAX-compatible dataclasses
@chex.dataclass
class Config:
    learning_rate: float
    batch_size: int


config = Config(learning_rate=0.01, batch_size=32)
# Works seamlessly with JAX transformations like jit, vmap, etc.

# Debugging utilities
with chex.fake_jit():
    # JAX jit calls become identity functions for easier debugging
    jitted_fn = jax.jit(lambda x: x * 2)
result = jitted_fn(5) # Executes without compilation

Chex is organized into functional modules that address different aspects of JAX development:
Comprehensive validation utilities for JAX computations including shape assertions, value comparisons, device placement checks, and tree structure validation. Essential for robust JAX code development and debugging.
# Assertion signature summaries; bodies are elided with `...`.
# NOTE(review): some parameter names here may be abbreviated relative to
# chex's published API (which documents e.g. `assert_shape(inputs,
# expected_shapes)`) -- confirm before passing arguments by keyword.

# Check that `array` has `expected_shape`, e.g. chex.assert_shape(x, (2, 3)).
def assert_shape(array, expected_shape): ...
# Check that `first` and `second` are equal.
# NOTE(review): chex documents this as an `==` comparison that cannot compare
# multi-element arrays -- confirm, and prefer a tree/array assertion for those.
def assert_equal(first, second): ...
# Check that `array` has `expected_rank` dimensions, e.g. chex.assert_rank(x, 2).
def assert_rank(array, expected_rank): ...
# Presumably checks that every leaf of `tree` is finite (no NaN/inf) -- confirm.
def assert_tree_all_finite(tree): ...
def assert_devices_available(devices): ...

JAX-compatible dataclass implementation that works seamlessly with JAX transformations and pytree operations, providing structured data containers that integrate with the JAX ecosystem.
# Dataclass signature summaries; bodies are elided with `...`.
# `cls` defaults to None so the decorator works both bare (`@chex.dataclass`,
# as on `Config` in the usage example) and with options
# (`@chex.dataclass(frozen=True)`). With a required `cls`, the keyword-only
# options below could never be supplied in decorator form.
def dataclass(cls=None, *, init=True, repr=True, eq=True, **kwargs): ...
# Decorator applied to a dataclass type `cls`.
def mappable_dataclass(cls): ...
def register_dataclass_type_with_jax_tree_util(cls): ...

Testing framework that runs the same test code across multiple JAX execution variants (jitted vs. non-jitted, different devices, with pmap, etc.) for comprehensive validation.
# Testing-API signature summaries; bodies are elided with `...`.
class TestCase: ...
# Decorator choosing which variants a test runs under. Each flag enables one
# variant and is passed by keyword, e.g.
# `@chex.variants(with_jit=True, without_jit=True)` as in the MyTest example;
# it does not take a positional list of variant types.
def variants(
    test_method=None,
    with_jit=False,
    without_jit=False,
    with_device=False,
    without_device=False,
    with_pmap=False,
): ...
# Same flags as `variants`, but every variant is enabled by default.
def all_variants(
    test_method=None,
    with_jit=True,
    without_jit=True,
    with_device=True,
    without_device=True,
    with_pmap=True,
): ...
class ChexVariantType(Enum): ...
def params_product(*params_lists, named=False): ...

Complete type system for JAX development, including array types, shape specifications, device types, and other computational primitives used throughout the JAX ecosystem.
# Type-alias summaries. The trailing `...` in `Array` marks union members
# elided from this listing.
# Union of the listed array kinds (device, batched, sharded, NumPy) plus the
# elided members.
Array = Union[ArrayDevice, ArrayBatched, ArraySharded, ArrayNumpy, ...]
# Recursive tree of arrays: an Array leaf, an iterable of subtrees, or a
# mapping whose values are subtrees.
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
# A scalar given as a Python float or int.
Scalar = Union[float, int]
Shape = Sequence[Union[int, Any]]

Development tools for patching JAX functions with fake implementations, enabling easier debugging, testing in CPU-only environments, and controlled development workflows.
# Fake-transform signature summaries; bodies are elided with `...`.
# `fake_jit` is used as a context manager with no required arguments
# (`with chex.fake_jit(): ...` in the debugging example). It does not wrap a
# function directly, so it takes no `fn` parameter. The pmap variants follow
# the same pattern.
# NOTE(review): fake_pmap accepts further options (elided as **kwargs here);
# confirm their names against the chex docs.
def fake_jit(enable_patching=True): ...
def fake_pmap(enable_patching=True, **kwargs): ...
def fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True): ...
def set_n_cpu_devices(n): ...

Specialized utilities, including backend restriction, dimension mapping, jittable assertions, and deprecation management, for advanced JAX development scenarios.
# Utility signature summaries; bodies are elided with `...`.
def restrict_backends(*, allowed=None, forbidden=None): ...
# Constructed from named dimension sizes, e.g. `Dimensions(B=3, T=5)`.
# The sizes go to `__init__`. Writing `class Dimensions(**kwargs)` would instead
# evaluate `kwargs` as class keyword arguments when the class statement runs,
# which raises NameError.
class Dimensions:
    def __init__(self, **dim_sizes): ...
def chexify(fn, async_check=True, errors=...): ...