Comprehensive utilities library for JAX testing, debugging, and instrumentation
npx @tessl/cli install tessl/pypi-chex@0.1.00
# Chex

Chex is a library of utilities for testing, debugging, and instrumenting JAX code. It provides assertion functions for validating the shapes, ranks, and values of arrays in JAX computations; debugging utilities such as `fake_pmap`, which replaces `jax.pmap` with `jax.vmap` so multi-device code can run on a single device; and a testing framework that runs the same test code under multiple variants (jitted and non-jitted, different devices).

## Package Information

- **Package Name**: chex
- **Language**: Python
- **Installation**: `pip install chex`
- **Requires**: Python >=3.11, JAX >=0.7.0

## Core Imports

```python
import chex
```

Individual names can also be imported directly from the top-level package:

```python
# Assertions and dataclasses
from chex import assert_shape, assert_equal, dataclass
# Debugging and testing utilities
from chex import fake_jit, fake_pmap, variants, TestCase
```

## Basic Usage

```python
import chex
import jax
import jax.numpy as jnp

# Shape and rank assertions
x = jnp.array([[1, 2, 3], [4, 5, 6]])
chex.assert_shape(x, (2, 3))  # Verify expected shape
chex.assert_rank(x, 2)  # Verify number of dimensions

# Testing with variants: run the same test jitted and non-jitted
class MyTest(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    def test_my_function(self):
        def my_fn(x):
            return x + 1

        fn = self.variant(my_fn)  # Jitted or not, depending on the variant
        result = fn(jnp.array([1, 2, 3]))
        # assert_equal relies on `==` and cannot compare multi-element arrays,
        # so arrays (and pytrees of arrays) are compared with assert_trees_all_equal.
        chex.assert_trees_all_equal(result, jnp.array([2, 3, 4]))

# JAX-compatible dataclasses
@chex.dataclass
class Config:
    learning_rate: float
    batch_size: int

config = Config(learning_rate=0.01, batch_size=32)
# Instances are pytrees, so they work with jit, vmap, and jax.tree_util functions.

# Debugging utilities
with chex.fake_jit():
    # jax.jit is patched to a no-op, so the function runs eagerly without compilation
    jitted_fn = jax.jit(lambda x: x * 2)
    result = jitted_fn(5)
```

## Architecture

Chex is organized into modules, each addressing one aspect of JAX development:

- **Assertions**: Validation functions for shapes, values, types, and device placement
- **Testing**: Variant-based framework for running the same tests across JAX execution modes
- **Dataclasses**: JAX-compatible dataclasses registered as pytrees
- **Type System**: Type aliases for JAX arrays, shapes, and related primitives
- **Debugging**: Utilities that patch JAX transformations such as `jit` and `pmap` with fake implementations for easier debugging
- **Backend Management**: Tools for restricting which JAX backends may compile computations

## Capabilities

### Assertion Functions

Validation utilities for JAX computations: shape and rank assertions, value comparisons, device checks, and pytree structure validation.

```python { .api }
def assert_shape(array, expected_shape): ...
def assert_equal(first, second): ...
def assert_rank(array, expected_rank): ...
def assert_tree_all_finite(tree): ...
def assert_devices_available(n, devtype, backend=None, not_less_than=False): ...
```

[Assertions](./assertions.md)

### JAX-Compatible Dataclasses

A dataclass decorator that registers each decorated class as a JAX pytree, so instances can be passed through JAX transformations such as `jit` and `vmap` and through `jax.tree_util` functions.

```python { .api }
def dataclass(cls, *, init=True, repr=True, eq=True, **kwargs): ...
def mappable_dataclass(cls): ...
def register_dataclass_type_with_jax_tree_util(cls): ...
```

[Dataclasses](./dataclasses.md)

### Test Variants and Testing Infrastructure

Testing framework that runs the same test code across multiple JAX execution variants (jitted and non-jitted, with and without device placement of inputs, under `pmap`).

```python { .api }
class TestCase: ...
def variants(test_method=None, with_jit=False, without_jit=False, with_device=False, without_device=False, with_pmap=False): ...
def all_variants(test_method=None, with_jit=True, without_jit=True, with_device=True, without_device=True, with_pmap=True): ...
class ChexVariantType(Enum): ...
def params_product(*params_lists, named=False): ...
```

[Testing](./testing.md)

### Type Definitions

Type aliases for JAX development covering arrays, pytrees of arrays, shapes, devices, and other primitives. They are intended for annotations and are not checked at runtime.

```python { .api }
Array = Union[ArrayDevice, ArrayBatched, ArraySharded, ArrayNumpy, ...]
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
Scalar = Union[float, int]
Shape = Sequence[Union[int, Any]]
```

[Types](./types.md)

### Debugging and Development Utilities

Development tools that patch JAX transformations with fake implementations (for example, replacing `jax.pmap` with `jax.vmap`), which makes it easier to step through code in a debugger and to exercise multi-device code in CPU-only environments. Each `fake_*` function returns a context manager that applies the patch while it is active.

```python { .api }
def fake_jit(enable_patching=True): ...
def fake_pmap(enable_patching=True, jit_result=False): ...
def fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True): ...
def set_n_cpu_devices(n=None): ...
```

[Debugging](./debugging.md)

### Advanced Features

Specialized utilities: backend restriction, named dimension sizes for shape checks, assertions that can run inside jitted code (`chexify`), and deprecation management.

```python { .api }
def restrict_backends(*, allowed=None, forbidden=None): ...
class Dimensions:
    def __init__(self, **dim_sizes): ...
def chexify(fn, async_check=True, errors=...): ...
```

[Advanced](./advanced.md)