# Test Variants and Testing Infrastructure

Testing framework that enables running the same test code across multiple JAX execution variants (jitted vs. non-jitted, with and without explicit device placement, with pmap) for comprehensive validation of JAX code behavior.

## Capabilities

### Test Base Classes

Base test case class providing variant testing infrastructure.
```python { .api }
class TestCase(parameterized.TestCase):
    """
    Base class for Chex tests that use variants.

    Provides infrastructure for running tests across multiple JAX execution
    modes. Subclasses absl.testing.parameterized.TestCase to support test
    generator unrolling.
    """

    def variant(self, *args, **kwargs):
        """
        Access the current test variant function.

        This method is dynamically replaced by the @variants decorator with
        the appropriate transformation (jit, identity, etc.).

        Raises:
        - RuntimeError: If called without the @variants decorator.
        """
```

### Variant Decorators

Decorators for running tests across multiple execution modes.

```python { .api }
def variants(test_method=None, with_jit=False, without_jit=False,
             with_device=False, without_device=False, with_pmap=False):
    """
    Decorator to run a test across the selected variants.

    Parameters:
    - with_jit, without_jit, with_device, without_device, with_pmap:
      Booleans selecting which variants to generate (at least one must
      be True).

    Returns:
    - A test generator yielding one test per selected variant.

    Example:
        @variants(with_jit=True, without_jit=True)
        def test_function(self):
            fn = self.variant(my_function)
            # Test implementation
    """

def all_variants(test_method=None, with_jit=True, without_jit=True,
                 with_device=True, without_device=True, with_pmap=True):
    """
    Decorator to run a test across all available variants.

    Parameters:
    - with_jit, without_jit, with_device, without_device, with_pmap:
      Booleans allowing individual variants to be excluded (all default
      to True).

    Returns:
    - A test generator yielding one test per variant.
    """
```
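The one-test-per-variant behavior described above can be illustrated with a small pure-Python dispatch sketch. This is not chex's implementation; `make_variants` and `fake_jit` are hypothetical names, with `fake_jit` standing in for a transformation such as `jax.jit`:

```python
def make_variants(test_fn, transforms):
    # Produce one wrapped test per named transformation, mirroring how a
    # variants-style decorator yields one test per execution mode.
    return {name: transform(test_fn) for name, transform in transforms.items()}

def fake_jit(fn):
    # Stand-in for jax.jit: wraps the function without changing its result.
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapped

# 'without_jit' runs the function as-is; 'with_jit' runs the wrapped version.
transforms = {'without_jit': lambda fn: fn, 'with_jit': fake_jit}
tests = make_variants(lambda x: x + 1, transforms)

# Every variant computes the same value, which is the property variant
# testing is meant to verify.
results = {name: fn(41) for name, fn in tests.items()}
```

In real chex tests the transformation is applied through `self.variant` inside the test body rather than returned as a dict, but the dispatch idea is the same.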

### Variant Types

Enumeration of available test variant types.

```python { .api }
class ChexVariantType(Enum):
    """
    Enumeration of available Chex test variants.

    Use self.variant.type to get the type of the current test variant.
    """

    WITH_JIT = 1        # Function wrapped with jax.jit
    WITHOUT_JIT = 2     # Function executed directly (identity)
    WITH_DEVICE = 3     # Function executed on a specific device
    WITHOUT_DEVICE = 4  # Function executed on the default device
    WITH_PMAP = 5       # Function wrapped with jax.pmap
```

### Parameter Generation

Utilities for generating test parameter combinations.

```python { .api }
def params_product(*params_lists, named=False):
    """
    Generate the Cartesian product of parameter lists for parameterized tests.

    Parameters:
    - *params_lists: Sequences of parameter tuples; one tuple from each
      sequence is concatenated into every combination.
    - named: Whether to generate test names for
      parameterized.named_parameters (the first element of each tuple is
      then treated as a name fragment).

    Returns:
    - Sequence of parameter combinations.

    Example:
        # Generate all combinations of batch sizes and learning rates
        params = params_product([(32,), (64,)], [(0.01,), (0.001,)])
        # [(32, 0.01), (32, 0.001), (64, 0.01), (64, 0.001)]
    """
```
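The tuple-concatenation semantics can be sketched with `itertools.product`. This is an illustrative stand-in, not chex's implementation, and `product_of_params` is a hypothetical name:

```python
import itertools

def product_of_params(*params_lists):
    # For every combination, concatenate one tuple drawn from each list.
    return [sum(combo, ()) for combo in itertools.product(*params_lists)]

combos = product_of_params([(32,), (64,)], [(0.01,), (0.001,)])
# [(32, 0.01), (32, 0.001), (64, 0.01), (64, 0.001)]
```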

## Usage Examples

### Basic Variant Testing

```python
import chex
import jax
import jax.numpy as jnp

class MyTest(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    def test_addition_function(self):
        def add_one(x):
            return x + 1

        # Get the variant-appropriate version of the function
        fn = self.variant(add_one)

        # Test the function
        result = fn(jnp.array([1, 2, 3]))
        expected = jnp.array([2, 3, 4])

        chex.assert_trees_all_equal(result, expected)

        # Access the variant type if needed
        if self.variant.type == chex.ChexVariantType.WITH_JIT:
            # This test is running with jit
            pass
```

### Testing All Variants

```python
class ComprehensiveTest(chex.TestCase):

    @chex.all_variants
    def test_matrix_multiply(self):
        def matmul(a, b):
            return jnp.dot(a, b)

        fn = self.variant(matmul)

        a = jnp.array([[1, 2], [3, 4]])
        b = jnp.array([[5, 6], [7, 8]])

        result = fn(a, b)
        expected = jnp.array([[19, 22], [43, 50]])

        chex.assert_trees_all_equal(result, expected)
```

### Parameterized Variant Testing

```python
from absl.testing import parameterized

class ParameterizedVariantTest(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.parameters(
        {'batch_size': 32, 'input_dim': 784},
        {'batch_size': 64, 'input_dim': 1024},
    )
    def test_neural_network_layer(self, batch_size, input_dim):
        def linear_layer(x, weights, bias):
            return jnp.dot(x, weights) + bias

        fn = self.variant(linear_layer)

        # Create test data
        x = jnp.ones((batch_size, input_dim))
        weights = jnp.ones((input_dim, 10))
        bias = jnp.zeros(10)

        result = fn(x, weights, bias)

        # Verify the output shape
        chex.assert_shape(result, (batch_size, 10))

        # Verify the computation: each output element is a sum of input_dim ones
        expected = jnp.full((batch_size, 10), float(input_dim))
        chex.assert_trees_all_close(result, expected)
```
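The expected value in the example above follows from the shapes alone: each output element is the dot product of a length-`input_dim` row of ones with a column of ones, plus a zero bias. A quick pure-Python check of that arithmetic (`expected_output_value` is a hypothetical helper for illustration):

```python
def expected_output_value(input_dim):
    # Dot product of two all-ones vectors of length input_dim,
    # plus the zero bias term.
    return sum(1 * 1 for _ in range(input_dim)) + 0

# expected_output_value(784) == 784, matching jnp.full(..., float(input_dim))
```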

### Using Parameter Products

```python
class ProductTest(chex.TestCase):

    # Generate all combinations of optimizers and learning rates
    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.parameters(
        *chex.params_product(
            [('sgd',), ('adam',), ('rmsprop',)],
            [(0.1,), (0.01,), (0.001,)],
        )
    )
    def test_optimizer_update(self, optimizer_name, learning_rate):
        def update_step(params, grads, lr):
            return params - lr * grads

        fn = self.variant(update_step)

        params = jnp.array([1.0, 2.0, 3.0])
        grads = jnp.array([0.1, 0.2, 0.3])

        updated_params = fn(params, grads, learning_rate)
        expected = params - learning_rate * grads

        chex.assert_trees_all_close(updated_params, expected)
```

### Testing with Device Variants

```python
class DeviceTest(chex.TestCase):

    @chex.variants(with_device=True, without_device=True)
    def test_device_placement(self):
        def compute_sum(x):
            return jnp.sum(x)

        fn = self.variant(compute_sum)

        x = jnp.array([1, 2, 3, 4, 5])
        result = fn(x)

        chex.assert_equal(result, 15)

        # Check device placement if needed
        if hasattr(result, 'device'):
            # Verify device placement based on the variant type
            pass
```

### Testing with Pmap Variants

```python
class PmapTest(chex.TestCase):

    @chex.variants(with_pmap=True)
    def test_parallel_computation(self):
        def parallel_square(x):
            return x ** 2

        fn = self.variant(parallel_square)

        # Create data with a leading axis matching the device count
        n_devices = jax.local_device_count()
        x = jnp.arange(n_devices * 4).reshape(n_devices, 4)

        result = fn(x)
        expected = x ** 2

        chex.assert_trees_all_equal(result, expected)
```

### Advanced Variant Usage

```python
class AdvancedVariantTest(chex.TestCase):

    def setUp(self):
        super().setUp()
        # Setup that runs before each variant
        self.tolerance = 1e-6

    @chex.all_variants
    def test_gradient_computation(self):
        def loss_fn(params, data):
            return jnp.sum((params['w'] @ data - params['b']) ** 2)

        # Get the variant-appropriate version
        loss_fn = self.variant(loss_fn)

        # Create test data
        params = {'w': jnp.array([[1.0, 2.0]]), 'b': jnp.array([0.5])}
        data = jnp.array([[1.0], [2.0]])

        # Compute gradients
        grad_fn = jax.grad(loss_fn)
        grads = grad_fn(params, data)

        # Verify the gradient structure matches params
        chex.assert_trees_all_equal_structs(grads, params)

        # Verify the gradients are finite
        chex.assert_tree_all_finite(grads)

    @chex.variants(with_jit=True, without_jit=True)
    def test_variant_type_specific_behavior(self):
        """Demonstrates variant-specific testing logic."""
        def expensive_computation(x):
            # A computation that might behave differently jitted vs. non-jitted
            return jnp.sum(jnp.sin(x) * jnp.cos(x))

        fn = self.variant(expensive_computation)
        x = jnp.linspace(0, 2 * jnp.pi, 1000)
        result = fn(x)

        # Branch on the variant type when expectations differ
        if self.variant.type == chex.ChexVariantType.WITH_JIT:
            # The jitted version may show slightly different numerical behavior
            chex.assert_rank(result, 0)
        else:
            # Non-jitted version
            chex.assert_rank(result, 0)
```

## Key Features

### Comprehensive Coverage
- Tests the same logic across multiple execution modes
- Catches bugs that only appear in specific configurations
- Ensures consistent behavior between jitted and non-jitted code

### Easy Integration
- Drop-in replacement for standard test classes
- Works with existing parameterized testing frameworks
- Minimal changes to existing test code

### Flexible Configuration
- Choose specific variants or test all of them
- Combine with parameterized testing
- Support for device-specific testing

### Debugging Support
- Access to the variant type within tests
- Clear error messages when variants fail
- Integration with the Chex assertion framework

## Best Practices

### Use Meaningful Test Names
```python
@chex.variants(with_jit=True, without_jit=True)
def test_neural_network_forward_pass_consistency(self):
    # Clear test purpose
    pass
```

### Test Critical Paths
```python
# Focus variant testing on functions that will be jitted in practice
@chex.all_variants
def test_training_step(self):
    # This will be jitted in real usage
    pass
```

### Combine with Assertions
```python
@chex.all_variants
def test_with_comprehensive_checks(self):
    fn = self.variant(my_function)
    result = fn(input_data)

    # Use Chex assertions for thorough validation
    chex.assert_shape(result, expected_shape)
    chex.assert_tree_all_finite(result)
```