
tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation

Describes: `pypipkg:pypi/chex@0.1.x`

To install, run `npx @tessl/cli install tessl/pypi-chex@0.1.0`.

# Chex

Chex is a utilities library for testing, debugging, and instrumenting JAX code. It provides three main kinds of tools. Assertion functions validate array shapes, ranks, and other properties in JAX computations. Debugging utilities can, for example, replace `pmap` with `vmap` or turn `jit` into a no-op to simplify development. Testing infrastructure runs the same test code under multiple variants, such as jitted vs. non-jitted or with and without device placement.

## Package Information

- **Package Name**: chex
- **Language**: Python
- **Installation**: `pip install chex`
- **Requires**: Python >=3.11, JAX >=0.7.0

## Core Imports

```python
import chex
```

Individual functions and classes can also be imported directly:

```python
# Core assertions and utilities
from chex import assert_shape, assert_equal, dataclass
from chex import fake_jit, fake_pmap, variants, TestCase
```

## Basic Usage

```python
import chex
import jax
import jax.numpy as jnp

# Shape and dimension assertions
x = jnp.array([[1, 2, 3], [4, 5, 6]])
chex.assert_shape(x, (2, 3))  # Verify expected shape
chex.assert_rank(x, 2)  # Verify number of dimensions

# Testing with variants: run the same test jitted and non-jitted
class MyTest(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    def test_my_function(self):
        def my_fn(x):
            return x + 1

        fn = self.variant(my_fn)  # Jitted or not, depending on the variant
        result = fn(jnp.array([1, 2, 3]))
        # assert_equal cannot compare multi-element arrays; compare trees instead
        chex.assert_trees_all_equal(result, jnp.array([2, 3, 4]))

# JAX-compatible dataclasses
@chex.dataclass
class Config:
    learning_rate: float
    batch_size: int

config = Config(learning_rate=0.01, batch_size=32)
# Instances are pytrees, so they can be passed through jit, vmap, etc.

# Debugging utilities
with chex.fake_jit():
    # Inside this context, jax.jit is patched into a no-op wrapper
    jitted_fn = jax.jit(lambda x: x * 2)
    result = jitted_fn(5)  # Executes eagerly, without compilation
```

## Architecture

Chex is organized into modules that address different aspects of JAX development:

- **Assertions**: validation functions for shapes, values, types, and device placement
- **Testing**: a variant-based testing framework for running tests across different JAX execution modes
- **Dataclasses**: a JAX-compatible dataclass implementation with PyTree integration
- **Type System**: type aliases for JAX arrays, shapes, and other computational primitives
- **Debugging**: utilities that patch JAX transformations (such as `jit` and `pmap`) to simplify development and debugging
- **Backend Management**: tools for controlling and restricting which JAX backends may compile code

## Capabilities

### Assertion Functions

Validation utilities for JAX computations, including shape assertions, value comparisons, device-placement checks, and tree-structure validation. They catch shape and numerical errors early when developing and debugging JAX code.

```python { .api }
def assert_shape(array, expected_shape): ...
def assert_equal(first, second): ...
def assert_rank(array, expected_rank): ...
def assert_tree_all_finite(tree): ...
def assert_devices_available(n, devtype, backend=None, not_less_than=False): ...
```

[Assertions](./assertions.md)

### JAX-Compatible Dataclasses

A dataclass implementation that registers classes as JAX pytrees. Instances can then be passed through JAX transformations and `jax.tree_util` operations like any other structured container.

```python { .api }
def dataclass(cls=None, *, init=True, repr=True, eq=True, **kwargs): ...
def mappable_dataclass(cls): ...
def register_dataclass_type_with_jax_tree_util(cls): ...
```

[Dataclasses](./dataclasses.md)

### Test Variants and Testing Infrastructure

A testing framework that runs the same test code under multiple JAX execution variants, such as jitted vs. non-jitted, with or without device placement, and under `pmap`.

```python { .api }
class TestCase: ...
def variants(**variant_flags): ...  # e.g. with_jit=True, without_jit=True
def all_variants(**variant_flags): ...
class ChexVariantType(Enum): ...
def params_product(*params_lists, named=False): ...
```

[Testing](./testing.md)

### Type Definitions

Type aliases for JAX development, including array types, shape specifications, device types, and other computational primitives used throughout the JAX ecosystem.

```python { .api }
Array = Union[ArrayDevice, ArrayBatched, ArraySharded, ArrayNumpy, ...]
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
Scalar = Union[float, int]
Shape = Sequence[Union[int, Any]]
```

[Types](./types.md)

### Debugging and Development Utilities

These tools patch JAX transformations with fake implementations, for example making `jit` a no-op or replacing `pmap` with `vmap`. They can also simulate multiple CPU devices. Together, these make code easier to debug and let you test multi-device logic on CPU-only machines.

```python { .api }
def fake_jit(enable_patching=True): ...  # returns a context manager
def fake_pmap(enable_patching=True, jit_result=False): ...  # returns a context manager; further options omitted
def fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True): ...
def set_n_cpu_devices(n): ...
```

[Debugging](./debugging.md)

### Advanced Features

Specialized utilities including backend restriction, dimension mapping, jittable assertions, and deprecation helpers for advanced JAX development scenarios.

```python { .api }
def restrict_backends(*, allowed=None, forbidden=None): ...
class Dimensions:
    def __init__(self, **dim_sizes): ...
def chexify(fn, async_check=True, errors=...): ...
```

[Advanced](./advanced.md)