Comprehensive utilities library for JAX testing, debugging, and instrumentation
73
Comprehensive validation utilities for JAX computations. These functions provide essential testing and debugging capabilities for validating tensor properties, shapes, values, and computational correctness in JAX programs.
Functions for validating array shapes, dimensions, and structural properties.
def assert_shape(array, expected_shape):
"""
Assert that array has the expected shape.
Parameters:
- array: Array to check
- expected_shape: Expected shape tuple, supports None for wildcard dimensions
"""
def assert_rank(array, expected_rank):
"""
Assert that array has the expected number of dimensions.
Parameters:
- array: Array to check
- expected_rank: Expected number of dimensions (int)
"""
def assert_size(array, expected_size):
"""
Assert that array has the expected total size.
Parameters:
- array: Array to check
- expected_size: Expected total number of elements (int)
"""
def assert_equal_shape(inputs, *, dims=None):
"""
Assert that all arrays have the same shape.
Parameters:
- inputs: Sequence of arrays to compare
- dims: Optional int or sequence of ints specifying which dimensions to compare
"""
def assert_equal_rank(inputs):
"""
Assert that all arrays have the same rank (number of dimensions).
Parameters:
- inputs: Sequence of arrays to compare
"""
def assert_equal_size(inputs):
"""
Assert that all arrays have the same total size.
Parameters:
- inputs: Sequence of arrays to compare
"""
def assert_equal_shape_prefix(inputs, prefix_len):
"""
Assert that the leading prefix_len dimensions of all inputs have same shape.
Parameters:
- inputs: Sequence of arrays to compare
- prefix_len: Number of leading dimensions to compare
"""
def assert_equal_shape_suffix(inputs, suffix_len):
"""
Assert that the final suffix_len dimensions of all inputs have same shape.
Parameters:
- inputs: Sequence of arrays to compare
- suffix_len: Number of trailing dimensions to compare
"""Functions for validating specific axis dimensions with comparison operators.
def assert_axis_dimension(tensor, axis, expected):
"""
Assert that a specific axis has the expected dimension size.
Parameters:
- tensor: Array to check
- axis: Axis index to check
- expected: Expected dimension size for the axis
"""
def assert_axis_dimension_comparator(tensor, axis, pass_fn, error_string):
"""
Assert that pass_fn(tensor.shape[axis]) passes.
Used to implement ==, >, >=, <, <= checks.
Parameters:
- tensor: JAX array to check
- axis: Axis index to check
- pass_fn: Function that takes dimension size and returns bool
- error_string: Error message to display if assertion fails
"""
def assert_axis_dimension_gt(tensor, axis, val):
"""
Assert that axis dimension is greater than the given value.
Parameters:
- tensor: Array to check
- axis: Axis index to check
- val: Minimum size (exclusive)
"""
def assert_axis_dimension_gteq(tensor, axis, val):
"""
Assert that axis dimension is greater than or equal to the given value.
Parameters:
- tensor: Array to check
- axis: Axis index to check
- val: Minimum size (inclusive)
"""
def assert_axis_dimension_lt(tensor, axis, val):
"""
Assert that axis dimension is less than the given value.
Parameters:
- tensor: Array to check
- axis: Axis index to check
- val: Maximum size (exclusive)
"""
def assert_axis_dimension_lteq(tensor, axis, val):
"""
Assert that axis dimension is less than or equal to the given value.
Parameters:
- tensor: Array to check
- axis: Axis index to check
- val: Maximum size (inclusive)
"""Functions for validating array values and content properties.
def assert_equal(first, second):
"""
Assert that two objects are equal as determined by the == operator.
Arrays with more than one element cannot be compared.
Use assert_trees_all_close to compare arrays.
Parameters:
- first: First object to compare
- second: Second object to compare
"""
def assert_scalar(value):
"""
Assert that value is a scalar (rank-0 array or Python scalar).
Parameters:
- value: Value to check
"""
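def assert_scalar_in(x, min_, max_, included=True):
"""
Assert that x is a scalar lying within the range bounded by min_ and max_.
Parameters:
- x: Scalar value to check
- min_: Lower bound of the allowed range
- max_: Upper bound of the allowed range
- included: If True (default), the bounds are inclusive (min_ <= x <= max_); if False, they are exclusive (min_ < x < max_)
"""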
def assert_scalar_positive(value):
"""
Assert that scalar value is positive (> 0).
Parameters:
- value: Scalar value to check
"""
def assert_scalar_non_negative(value):
"""
Assert that scalar value is non-negative (>= 0).
Parameters:
- value: Scalar value to check
"""
def assert_scalar_negative(value):
"""
Assert that scalar value is negative (< 0).
Parameters:
- value: Scalar value to check
"""
def assert_type(value, expected_type):
"""
Assert that value is of the expected type.
Parameters:
- value: Value to check
- expected_type: Expected type or tuple of types
"""Functions for validating JAX pytree structures and their properties.
def assert_tree_shape(tree, expected_shape):
"""
Assert that all arrays in the tree have the expected shape.
Parameters:
- tree: JAX pytree containing arrays
- expected_shape: Expected shape for all arrays in tree
"""
def assert_tree_shape_prefix(tree, prefix_shape):
"""
Assert that all arrays in tree have shapes starting with given prefix.
Parameters:
- tree: JAX pytree containing arrays
- prefix_shape: Shape prefix that all arrays should have
"""
def assert_tree_shape_suffix(tree, suffix_shape):
"""
Assert that all arrays in tree have shapes ending with given suffix.
Parameters:
- tree: JAX pytree containing arrays
- suffix_shape: Shape suffix that all arrays should have
"""
def assert_tree_all_finite(tree):
"""
Assert that all values in the tree are finite (not NaN or infinite).
Parameters:
- tree: JAX pytree containing arrays
"""
def assert_tree_has_only_ndarrays(tree):
"""
Assert that tree contains only numpy/JAX arrays.
Parameters:
- tree: JAX pytree to check
"""
def assert_tree_no_nones(tree):
"""
Assert that tree contains no None values.
Parameters:
- tree: JAX pytree to check
"""
def assert_tree_is_on_device(tree, device):
"""
Assert that all arrays in tree are on the specified device.
Parameters:
- tree: JAX pytree containing arrays
- device: Expected device
"""
def assert_tree_is_on_host(tree):
"""
Assert that all arrays in tree are on host (CPU).
Parameters:
- tree: JAX pytree containing arrays
"""
def assert_tree_is_sharded(tree):
"""
Assert that tree contains sharded arrays.
Parameters:
- tree: JAX pytree containing arrays
"""Functions for comparing multiple JAX pytrees.
def assert_trees_all_equal(*trees):
"""
Assert that all trees are exactly equal in structure and values.
Parameters:
- *trees: Variable number of JAX pytrees to compare
"""
def assert_trees_all_equal_comparator(tree1, tree2, comparator):
"""
Assert that two trees are equal using a custom comparator function.
Parameters:
- tree1, tree2: JAX pytrees to compare
- comparator: Function to compare individual array elements
"""
def assert_trees_all_equal_dtypes(*trees):
"""
Assert that all trees have matching data types.
Parameters:
- *trees: Variable number of JAX pytrees to compare
"""
def assert_trees_all_equal_shapes(*trees):
"""
Assert that all trees have matching shapes.
Parameters:
- *trees: Variable number of JAX pytrees to compare
"""
def assert_trees_all_equal_shapes_and_dtypes(*trees):
"""
Assert that all trees have matching shapes and data types.
Parameters:
- *trees: Variable number of JAX pytrees to compare
"""
def assert_trees_all_equal_sizes(*trees):
"""
Assert that all trees have matching sizes.
Parameters:
- *trees: Variable number of JAX pytrees to compare
"""
def assert_trees_all_equal_structs(*trees):
"""
Assert that all trees have matching structures (ignoring values).
Parameters:
- *trees: Variable number of JAX pytrees to compare
"""
def assert_trees_all_close(*trees, rtol=1e-06, atol=0.0):
"""
Assert that all trees have the same structure and that corresponding leaves are numerically close within tolerance.
Parameters:
- *trees: Variable number of JAX pytrees to compare
- rtol: Relative tolerance
- atol: Absolute tolerance
"""
def assert_trees_all_close_ulp(tree1, tree2, maxulp=4):
"""
Assert that trees are close within Units in the Last Place tolerance.
Parameters:
- tree1, tree2: JAX pytrees to compare
- maxulp: Maximum units in the last place difference allowed
"""Functions for validating device availability and placement.
def assert_devices_available(devices):
"""
Assert that specified devices are available.
Parameters:
- devices: List of device specifications or device objects
"""
def assert_gpu_available():
"""
Assert that at least one GPU device is available.
"""
def assert_tpu_available():
"""
Assert that at least one TPU device is available.
"""Helper functions for common validation patterns.
def assert_exactly_one_is_none(*values):
"""
Assert that exactly one of the given values is None.
Parameters:
- *values: Variable number of values to check
"""
def assert_not_both_none(value1, value2):
"""
Assert that at least one of the two values is not None.
Parameters:
- value1, value2: Values to check
"""
def assert_is_broadcastable(shape1, shape2):
"""
Assert that two shapes are broadcastable according to NumPy rules.
Parameters:
- shape1, shape2: Shape tuples to check
"""
def assert_is_divisible(dividend, divisor):
"""
Assert that dividend is evenly divisible by divisor.
Parameters:
- dividend: Number to divide
- divisor: Number to divide by
"""
def assert_numerical_grads(fn, args, order=1, **kwargs):
"""
Assert that analytical gradients match numerical gradients.
Parameters:
- fn: Function to test gradients for
- args: Arguments to pass to function
- order: Order of derivative to test
- **kwargs: Additional arguments for numerical gradient computation
"""Functions for controlling assertion behavior globally.
def enable_asserts():
"""
Enable all Chex assertions (default state).
"""
def disable_asserts():
"""
Disable all Chex assertions for performance.
"""
def if_args_not_none(fn, *args, **kwargs):
"""
Execute assertion function only if all positional arguments are not None.
Parameters:
- fn: Assertion function to conditionally execute
- *args: Arguments to pass to fn
- **kwargs: Keyword arguments to pass to fn
"""
def clear_trace_counter():
"""
Clear the trace counter used by assert_max_traces.
"""
def assert_max_traces(fn, n):
"""
Decorator/wrapper to assert function is traced at most n times.
Parameters:
- fn: Function to wrap or n (number of max traces) if used as decorator
- n: Maximum number of traces allowed (if fn is a function)
Returns:
- Wrapped function or decorator
"""import chex
import jax
import jax.numpy as jnp
# Create test arrays
x = jnp.array([[1, 2, 3], [4, 5, 6]])  # Shape: (2, 3)
y = jnp.zeros((2, 3))

# Validate shapes
chex.assert_shape(x, (2, 3))  # Passes
chex.assert_equal_shape([x, y])  # Passes - takes a sequence of arrays, not separate arguments
chex.assert_rank(x, 2)  # Passes

# Wildcard dimensions
z = jnp.ones((2, 5))
chex.assert_shape(z, (2, None))  # Passes - None matches any size

# Create a pytree
tree = {
'weights': jnp.array([[1, 2], [3, 4]]),
'bias': jnp.array([0.1, 0.2]),
'nested': {'param': jnp.array([1.0])}
}
# Validate tree properties
chex.assert_tree_all_finite(tree)
chex.assert_tree_has_only_ndarrays(tree)
# Compare trees
tree2 = jax.tree_map(lambda x: x + 0.01, tree)
chex.assert_trees_all_close(tree, tree2, atol=0.02)def process_data(data, weights=None):
chex.assert_shape(data, (None, 10)) # Any batch size, 10 features
# Only check weights if provided
chex.if_args_not_none(chex.assert_shape, weights, (10, 5))
return data @ weights if weights is not None else dataInstall with Tessl CLI
npx tessl i tessl/pypi-chex
evals
scenario-1
scenario-2
scenario-3
scenario-4
scenario-5
scenario-6
scenario-7
scenario-8
scenario-9
scenario-10