# Assertion Functions

Chex assertion functions validate properties of JAX computations: array shapes and ranks, scalar values, pytree structure and contents, and device placement. They are intended for testing and debugging JAX programs.

## Capabilities

### Shape and Dimension Assertions

Functions for validating array shapes, dimensions, and structural properties.

```python { .api }
def assert_shape(array, expected_shape):
    """
    Assert that array has the expected shape.

    Parameters:
    - array: Array to check
    - expected_shape: Expected shape tuple, supports None for wildcard dimensions
    """

def assert_rank(array, expected_rank):
    """
    Assert that array has the expected number of dimensions.

    Parameters:
    - array: Array to check
    - expected_rank: Expected number of dimensions (int)
    """

def assert_size(array, expected_size):
    """
    Assert that array has the expected total size.

    Parameters:
    - array: Array to check
    - expected_size: Expected total number of elements (int)
    """

def assert_equal_shape(inputs, *, dims=None):
    """
    Assert that all arrays have the same shape.

    Parameters:
    - inputs: Sequence of arrays to compare
    - dims: Optional int or sequence of ints specifying which dimensions to compare
    """

def assert_equal_rank(inputs):
    """
    Assert that all arrays have the same rank (number of dimensions).

    Parameters:
    - inputs: Sequence of arrays to compare
    """

def assert_equal_size(inputs):
    """
    Assert that all arrays have the same total size.

    Parameters:
    - inputs: Sequence of arrays to compare
    """

def assert_equal_shape_prefix(inputs, prefix_len):
    """
    Assert that the leading prefix_len dimensions of all inputs have the same shape.

    Parameters:
    - inputs: Sequence of arrays to compare
    - prefix_len: Number of leading dimensions to compare
    """

def assert_equal_shape_suffix(inputs, suffix_len):
    """
    Assert that the final suffix_len dimensions of all inputs have the same shape.

    Parameters:
    - inputs: Sequence of arrays to compare
    - suffix_len: Number of trailing dimensions to compare
    """
```

### Axis-Specific Assertions

Functions for validating specific axis dimensions with comparison operators.

```python { .api }
def assert_axis_dimension(tensor, axis, expected):
    """
    Assert that a specific axis has the expected dimension size.

    Parameters:
    - tensor: Array to check
    - axis: Axis index to check
    - expected: Expected dimension size for the axis
    """

def assert_axis_dimension_comparator(tensor, axis, pass_fn, error_string):
    """
    Assert that pass_fn(tensor.shape[axis]) passes.

    Used to implement ==, >, >=, <, <= checks.

    Parameters:
    - tensor: JAX array to check
    - axis: Axis index to check
    - pass_fn: Function that takes dimension size and returns bool
    - error_string: Error message to display if assertion fails
    """

def assert_axis_dimension_gt(tensor, axis, val):
    """
    Assert that axis dimension is greater than the given value.

    Parameters:
    - tensor: Array to check
    - axis: Axis index to check
    - val: Minimum size (exclusive)
    """

def assert_axis_dimension_gteq(tensor, axis, val):
    """
    Assert that axis dimension is greater than or equal to the given value.

    Parameters:
    - tensor: Array to check
    - axis: Axis index to check
    - val: Minimum size (inclusive)
    """

def assert_axis_dimension_lt(tensor, axis, val):
    """
    Assert that axis dimension is less than the given value.

    Parameters:
    - tensor: Array to check
    - axis: Axis index to check
    - val: Maximum size (exclusive)
    """

def assert_axis_dimension_lteq(tensor, axis, val):
    """
    Assert that axis dimension is less than or equal to the given value.

    Parameters:
    - tensor: Array to check
    - axis: Axis index to check
    - val: Maximum size (inclusive)
    """
```

### Value and Content Assertions

Functions for validating array values and content properties.

```python { .api }
def assert_equal(first, second):
    """
    Assert that two objects are equal as determined by the == operator.

    Arrays with more than one element cannot be compared.
    Use assert_trees_all_close to compare arrays.

    Parameters:
    - first: First object to compare
    - second: Second object to compare
    """

def assert_scalar(value):
    """
    Assert that value is a scalar (rank-0 array or Python scalar).

    Parameters:
    - value: Value to check
    """

def assert_scalar_in(value, min_, max_, included=True):
    """
    Assert that scalar value lies within the interval [min_, max_].

    Parameters:
    - value: Scalar value to check
    - min_: Lower bound of the interval
    - max_: Upper bound of the interval
    - included: Whether the bounds themselves count as valid (default True)
    """

def assert_scalar_positive(value):
    """
    Assert that scalar value is positive (> 0).

    Parameters:
    - value: Scalar value to check
    """

def assert_scalar_non_negative(value):
    """
    Assert that scalar value is non-negative (>= 0).

    Parameters:
    - value: Scalar value to check
    """

def assert_scalar_negative(value):
    """
    Assert that scalar value is negative (< 0).

    Parameters:
    - value: Scalar value to check
    """

def assert_type(value, expected_type):
    """
    Assert that value is of the expected type.

    Parameters:
    - value: Value to check
    - expected_type: Expected type or tuple of types
    """
```

### Tree Structure Assertions

Functions for validating JAX pytree structures and their properties.

```python { .api }
def assert_tree_shape(tree, expected_shape):
    """
    Assert that all arrays in the tree have the expected shape.

    Parameters:
    - tree: JAX pytree containing arrays
    - expected_shape: Expected shape for all arrays in tree
    """

def assert_tree_shape_prefix(tree, prefix_shape):
    """
    Assert that all arrays in tree have shapes starting with given prefix.

    Parameters:
    - tree: JAX pytree containing arrays
    - prefix_shape: Shape prefix that all arrays should have
    """

def assert_tree_shape_suffix(tree, suffix_shape):
    """
    Assert that all arrays in tree have shapes ending with given suffix.

    Parameters:
    - tree: JAX pytree containing arrays
    - suffix_shape: Shape suffix that all arrays should have
    """

def assert_tree_all_finite(tree):
    """
    Assert that all values in the tree are finite (not NaN or infinite).

    Parameters:
    - tree: JAX pytree containing arrays
    """

def assert_tree_has_only_ndarrays(tree):
    """
    Assert that tree contains only numpy/JAX arrays.

    Parameters:
    - tree: JAX pytree to check
    """

def assert_tree_no_nones(tree):
    """
    Assert that tree contains no None values.

    Parameters:
    - tree: JAX pytree to check
    """

def assert_tree_is_on_device(tree, device):
    """
    Assert that all arrays in tree are on the specified device.

    Parameters:
    - tree: JAX pytree containing arrays
    - device: Expected device
    """

def assert_tree_is_on_host(tree):
    """
    Assert that all arrays in tree are on host (CPU).

    Parameters:
    - tree: JAX pytree containing arrays
    """

def assert_tree_is_sharded(tree):
    """
    Assert that tree contains sharded arrays.

    Parameters:
    - tree: JAX pytree containing arrays
    """
```

### Multi-Tree Comparisons

Functions for comparing multiple JAX pytrees.

```python { .api }
def assert_trees_all_equal(*trees):
    """
    Assert that all trees are exactly equal in structure and values.

    Parameters:
    - *trees: Variable number of JAX pytrees to compare
    """

def assert_trees_all_equal_comparator(tree1, tree2, comparator):
    """
    Assert that two trees are equal using a custom comparator function.

    Parameters:
    - tree1, tree2: JAX pytrees to compare
    - comparator: Function to compare individual array elements
    """

def assert_trees_all_equal_dtypes(*trees):
    """
    Assert that all trees have matching data types.

    Parameters:
    - *trees: Variable number of JAX pytrees to compare
    """

def assert_trees_all_equal_shapes(*trees):
    """
    Assert that all trees have matching shapes.

    Parameters:
    - *trees: Variable number of JAX pytrees to compare
    """

def assert_trees_all_equal_shapes_and_dtypes(*trees):
    """
    Assert that all trees have matching shapes and data types.

    Parameters:
    - *trees: Variable number of JAX pytrees to compare
    """

def assert_trees_all_equal_sizes(*trees):
    """
    Assert that all trees have matching sizes.

    Parameters:
    - *trees: Variable number of JAX pytrees to compare
    """

def assert_trees_all_equal_structs(*trees):
    """
    Assert that all trees have matching structures (ignoring values).

    Parameters:
    - *trees: Variable number of JAX pytrees to compare
    """

def assert_trees_all_close(tree1, tree2, rtol=1e-05, atol=1e-08):
    """
    Assert that trees are numerically close within tolerance.

    Parameters:
    - tree1, tree2: JAX pytrees to compare
    - rtol: Relative tolerance
    - atol: Absolute tolerance
    """

def assert_trees_all_close_ulp(tree1, tree2, maxulp=4):
    """
    Assert that trees are close within Units in the Last Place tolerance.

    Parameters:
    - tree1, tree2: JAX pytrees to compare
    - maxulp: Maximum units in the last place difference allowed
    """
```

### Device and Hardware Assertions

Functions for validating device availability and placement.

```python { .api }
def assert_devices_available(devices):
    """
    Assert that specified devices are available.

    Parameters:
    - devices: List of device specifications or device objects
    """

def assert_gpu_available():
    """
    Assert that at least one GPU device is available.
    """

def assert_tpu_available():
    """
    Assert that at least one TPU device is available.
    """
```

### Utility Assertions

Helper functions for common validation patterns.

```python { .api }
def assert_exactly_one_is_none(*values):
    """
    Assert that exactly one of the given values is None.

    Parameters:
    - *values: Variable number of values to check
    """

def assert_not_both_none(value1, value2):
    """
    Assert that at least one of the two values is not None.

    Parameters:
    - value1, value2: Values to check
    """

def assert_is_broadcastable(shape1, shape2):
    """
    Assert that two shapes are broadcastable according to NumPy rules.

    Parameters:
    - shape1, shape2: Shape tuples to check
    """

def assert_is_divisible(dividend, divisor):
    """
    Assert that dividend is evenly divisible by divisor.

    Parameters:
    - dividend: Number to divide
    - divisor: Number to divide by
    """

def assert_numerical_grads(fn, args, order=1, **kwargs):
    """
    Assert that analytical gradients match numerical gradients.

    Parameters:
    - fn: Function to test gradients for
    - args: Arguments to pass to function
    - order: Order of derivative to test
    - **kwargs: Additional arguments for numerical gradient computation
    """
```

### Assertion Control

Functions for controlling assertion behavior globally.

```python { .api }
def enable_asserts():
    """
    Enable all Chex assertions (default state).
    """

def disable_asserts():
    """
    Disable all Chex assertions for performance.
    """

def if_args_not_none(fn, *args, **kwargs):
    """
    Execute assertion function only if all positional arguments are not None.

    Parameters:
    - fn: Assertion function to conditionally execute
    - *args: Arguments to pass to fn
    - **kwargs: Keyword arguments to pass to fn
    """

def clear_trace_counter():
    """
    Clear the trace counter used by assert_max_traces.
    """

def assert_max_traces(fn, n):
    """
    Decorator/wrapper to assert function is traced at most n times.

    Parameters:
    - fn: Function to wrap or n (number of max traces) if used as decorator
    - n: Maximum number of traces allowed (if fn is a function)

    Returns:
    - Wrapped function or decorator
    """
```

## Usage Examples

### Basic Shape Validation

```python
import chex
import jax.numpy as jnp

# Create test arrays
x = jnp.array([[1, 2, 3], [4, 5, 6]])  # Shape: (2, 3)
y = jnp.zeros((2, 3))

# Validate shapes
chex.assert_shape(x, (2, 3))  # Passes
chex.assert_equal_shape([x, y])  # Passes - note list of arrays
chex.assert_rank(x, 2)  # Passes

# Wildcard dimensions
z = jnp.ones((2, 5))
chex.assert_shape(z, (2, None))  # Passes - None matches any size
```

### Tree Validation

```python
import chex
import jax
import jax.numpy as jnp

# Create a pytree (float leaves, so the perturbed copy keeps the same dtypes)
tree = {
    'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    'bias': jnp.array([0.1, 0.2]),
    'nested': {'param': jnp.array([1.0])}
}

# Validate tree properties
chex.assert_tree_all_finite(tree)
chex.assert_tree_has_only_ndarrays(tree)

# Compare trees: every leaf is shifted by 0.01, within the 0.02 tolerance
tree2 = jax.tree_util.tree_map(lambda x: x + 0.01, tree)
chex.assert_trees_all_close(tree, tree2, atol=0.02)
```

### Conditional Assertions

```python
import chex


def process_data(data, weights=None):
    chex.assert_shape(data, (None, 10))  # Any batch size, 10 features

    # Only check weights if provided
    chex.if_args_not_none(chex.assert_shape, weights, (10, 5))

    return data @ weights if weights is not None else data
```