tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation

Workspace: tessl
Visibility: Public
Describes: pkg:pypi/chex@0.1.x

To install, run

npx @tessl/cli install tessl/pypi-chex@0.1.0


Chex

Chex is a utilities library for JAX development that covers testing, debugging, and instrumentation. It provides:

  • assertion functions for validating array shapes, ranks, and values in JAX computations;
  • debugging utilities such as fake_pmap, which replaces pmap with vmap so that multi-device code can run and be stepped through on a single device;
  • testing infrastructure that runs the same test code across multiple variants: jitted and non-jitted, with and without device placement, and under pmap.

Package Information

  • Package Name: chex
  • Language: Python
  • Installation: pip install chex
  • Requires: Python >=3.11, JAX >=0.7.0

Core Imports

import chex

Individual functions and classes can also be imported directly:

# Core assertions and utilities
from chex import assert_shape, assert_equal, dataclass
from chex import fake_jit, fake_pmap, variants, TestCase

Basic Usage

import chex
import jax
import jax.numpy as jnp

# Shape and dimension assertions
x = jnp.array([[1, 2, 3], [4, 5, 6]])
chex.assert_shape(x, (2, 3))  # Verify expected shape
chex.assert_rank(x, 2)        # Verify number of dimensions

# Testing with variants: run the same test jitted and non-jitted
class MyTest(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    def test_my_function(self):
        def my_fn(x):
            return x + 1

        fn = self.variant(my_fn)  # jitted in one variant, plain in the other
        result = fn(jnp.array([1, 2, 3]))
        # assert_equal is meant for non-array values; compare arrays leaf-wise
        chex.assert_trees_all_equal(result, jnp.array([2, 3, 4]))

# JAX-compatible dataclasses
@chex.dataclass
class Config:
    learning_rate: float
    batch_size: int

config = Config(learning_rate=0.01, batch_size=32)
# Registered as a JAX pytree, so instances can pass through jit, vmap, and jax.tree_util

# Debugging utilities
with chex.fake_jit():
    # Inside the context, jax.jit returns the function unchanged, so it runs eagerly
    jitted_fn = jax.jit(lambda x: x * 2)
    result = jitted_fn(5)  # Executes without compilation

Architecture

Chex is organized into functional modules that address different aspects of JAX development:

  • Assertions: Validation functions for shapes, ranks, dtypes, values, tree structure, and device placement
  • Testing: Variant-based testing framework for running the same tests under different JAX execution modes
  • Dataclasses: JAX-compatible dataclass implementation with PyTree integration
  • Type System: Type aliases for JAX arrays, array trees, shapes, scalars, and devices
  • Debugging: Context managers that patch jax.jit and jax.pmap with fake implementations for easier debugging
  • Backend Management: Tools for restricting which JAX backends may compile computations

Capabilities

Assertion Functions

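Validation utilities for JAX computations, including shape and rank assertions, value comparisons, device-placement checks, and tree-structure validation. Static checks such as shape and rank only inspect array metadata, so they also work inside jitted functions. Value checks such as assert_tree_all_finite need concrete values unless the function is wrapped with chexify.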

def assert_shape(array, expected_shape): ...
def assert_equal(first, second): ...  # non-array values; use assert_trees_all_close for arrays
def assert_rank(array, expected_rank): ...
def assert_tree_all_finite(tree): ...
def assert_devices_available(n, devtype, backend=None, not_less_than=False): ...

Full reference: docs/assertions.md

JAX-Compatible Dataclasses

A dataclass decorator whose classes are registered as JAX pytrees. Instances can therefore be passed through jax.jit, jax.vmap, and jax.tree_util functions like any other container of arrays.

def dataclass(cls=None, *, init=True, repr=True, eq=True, frozen=False, mappable_dataclass=True, **kwargs): ...
def mappable_dataclass(cls): ...
def register_dataclass_type_with_jax_tree_util(cls): ...

Full reference: docs/dataclasses.md

Test Variants and Testing Infrastructure

Testing framework for running the same test code under multiple JAX execution variants: jitted and non-jitted, with and without explicit device placement, and under pmap.

class TestCase: ...
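def variants(test_method=None, with_jit=False, without_jit=False, with_device=False, without_device=False, with_pmap=False): ...
def all_variants(test_method=None, with_jit=True, without_jit=True, with_device=True, without_device=True, with_pmap=True): ...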
class ChexVariantType(Enum): ...
def params_product(*params_lists, named=False): ...

Full reference: docs/testing.md

Type Definitions

Type aliases for annotating JAX code. They cover arrays, array trees (nested pytrees of arrays), shapes, scalars, devices, and other primitives used throughout the JAX ecosystem.

Array = Union[ArrayDevice, ArrayBatched, ArraySharded, ArrayNumpy, ...]
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
Scalar = Union[float, int]
Shape = Sequence[Union[int, Any]]

Full reference: docs/types.md

Debugging and Development Utilities

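Context managers that patch jax.jit and jax.pmap with fake implementations (for example, fake_pmap substitutes vmap for pmap). This lets code be debugged eagerly and multi-device code be exercised on a single device. set_n_cpu_devices makes XLA expose several CPU devices, so that real pmap can be tested on a CPU-only host.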

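def fake_jit(enable_patching=True): ...  # context manager
def fake_pmap(enable_patching=True, jit_result=False, ignore_axis_index_groups=False): ...
def fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True): ...
def set_n_cpu_devices(n=None): ...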

Full reference: docs/debugging.md

Advanced Features

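Specialized utilities for advanced JAX development:

  • restricting which backends may compile computations (restrict_backends);
  • naming dimension sizes for use in shape assertions (Dimensions);
  • making value assertions usable inside jitted code (chexify);
  • managing deprecation warnings.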

def restrict_backends(*, allowed=None, forbidden=None): ...
class Dimensions:
    def __init__(self, **kwargs): ...
def chexify(fn, async_check=True, errors=...): ...

Full reference: docs/advanced.md