tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation


Test Variants and Testing Infrastructure

Testing framework that enables running the same test code across multiple JAX execution variants (jitted vs non-jitted, different devices, with pmap, etc.) for comprehensive validation of JAX code behavior.

Capabilities

Test Base Classes

Base test case class providing variant testing infrastructure.

class TestCase(parameterized.TestCase):
    """
    Base class for Chex tests that use variants.
    
    Provides infrastructure for running tests across multiple JAX execution modes.
    Subclasses from absl.testing.parameterized.TestCase to support generator unrolling.
    """
    
    def variant(self, *args, **kwargs):
        """
        Access the current test variant function.
        
        This method is dynamically replaced by the @variants decorator
        with the appropriate transformation (jit, identity, etc.).
        
        Raises:
        - RuntimeError: If called without @variants decorator
        """
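The replacement of self.variant described above can be pictured with a small toy model. This is only a sketch of the idea, not Chex's implementation: toy_variants, VARIANTS, identity, and counted are hypothetical names, the wrappers are plain Python stand-ins for transformations such as jax.jit, and the sketch loops with subTest where Chex generates one test per variant.

```python
import functools
import io
import unittest


def identity(fn):
    """Hypothetical stand-in for the 'no transformation' variant."""
    return fn


def counted(fn):
    """Hypothetical stand-in for a transformation such as jax.jit."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return fn(*args, **kwargs)
    wrapper.calls = 0
    return wrapper


# Hypothetical registry mapping variant names to wrappers.
VARIANTS = {"plain": identity, "counted": counted}


def toy_variants(*names):
    """Run the decorated test once per variant, exposing it as self.variant."""
    def decorator(test_method):
        @functools.wraps(test_method)
        def run_all(self):
            for name in names:
                with self.subTest(variant=name):
                    # Swap in the wrapper for the current variant.
                    self.variant = VARIANTS[name]
                    test_method(self)
        return run_all
    return decorator


class ToyTest(unittest.TestCase):

    @toy_variants("plain", "counted")
    def test_add_one(self):
        fn = self.variant(lambda x: x + 1)
        self.assertEqual(fn(1), 2)


def run_toy_suite():
    """Run ToyTest quietly and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ToyTest)
    return unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
```

The test body never names a transformation; it only calls self.variant, which is why the same body can be reused across execution modes.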

Variant Decorators

Decorators for running tests across multiple execution modes.

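def variants(test_method, with_jit=False, without_jit=False,
             with_device=False, without_device=False, with_pmap=False):
    """
    Decorator to run a test across the selected variants.
    
    Normally applied with keyword flags only, e.g.
    @chex.variants(with_jit=True, without_jit=True).
    
    Parameters:
    - with_jit, without_jit, with_device, without_device, with_pmap:
      Boolean flags selecting which variants to generate (all default to False)
    
    Returns:
    - Generator yielding one test per selected variant
    
    Example:
    @variants(with_jit=True, without_jit=True)
    def test_function(self):
        fn = self.variant(my_function)
        # Test implementation
    """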

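def all_variants(test_method, with_jit=True, without_jit=True,
                 with_device=True, without_device=True, with_pmap=True):
    """
    Decorator to run a test across all available variants.
    
    Parameters:
    - with_jit, without_jit, with_device, without_device, with_pmap:
      Boolean flags (all default to True); pass False to exclude a variant,
      e.g. @chex.all_variants(with_pmap=False)
    
    Returns:
    - Generator yielding one test per variant
    """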

Variant Types

Enumeration of available test variant types.

class ChexVariantType(Enum):
    """
    Enumeration of available Chex test variants.
    
    Use self.variant.type to get the type of the current test variant.
    """
    
    WITH_JIT = 1        # Function wrapped with jax.jit
    WITHOUT_JIT = 2     # Function executed with JIT compilation disabled
    WITH_DEVICE = 3     # Arguments placed in device memory before the call
    WITHOUT_DEVICE = 4  # Arguments kept in host memory (NumPy arrays) before the call
    WITH_PMAP = 5       # Function wrapped with jax.pmap

Parameter Generation

Utilities for generating test parameter combinations.

def params_product(*params_lists, named=False):
    """
    Generate cartesian product of parameter lists for parameterized tests.
    
    Parameters:
    - *params_lists: Sequences of parameter tuples; one tuple is picked from
      each list and the picked tuples are concatenated
    - named: Whether to generate test names for parameterized.named_parameters;
      if True, the first element of every tuple is a name fragment
    
    Returns:
    - Sequence of parameter combinations
    
    Example:
    # Generate all combinations of batch sizes and learning rates
    params = params_product([(32,), (64,)], [(0.01,), (0.001,)])
    # [(32, 0.01), (32, 0.001), (64, 0.01), (64, 0.001)]
    """
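To make the combination rule concrete, here is a standard-library sketch of the behaviour described above. It is a hypothetical stand-in (params_product_sketch is not part of Chex) and rests on two assumptions: every entry in a parameter list is a tuple, and with named=True the first element of each tuple is a name fragment that is joined with underscores.

```python
import itertools


def params_product_sketch(*params_lists, named=False):
    """Hypothetical model of the cartesian-product rule; not Chex's code."""
    rows = []
    for combination in itertools.product(*params_lists):
        if named:
            # Join the name fragments, then concatenate the remaining values.
            name = "_".join(entry[0] for entry in combination)
            args = tuple(
                itertools.chain.from_iterable(entry[1:] for entry in combination)
            )
            rows.append((name, *args))
        else:
            # Concatenate the tuples picked from each list.
            rows.append(tuple(itertools.chain.from_iterable(combination)))
    return rows


plain = params_product_sketch([(32,), (64,)], [(0.01,), (0.001,)])

named = params_product_sketch(
    [("sgd", "sgd"), ("adam", "adam")],
    [("lr_high", 0.1), ("lr_low", 0.01)],
    named=True,
)
```

Under these assumptions, plain holds four (batch_size, learning_rate) tuples, and each row of named starts with a name such as "sgd_lr_high", which is the shape parameterized.named_parameters expects.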

Usage Examples

Basic Variant Testing

import chex
import jax
import jax.numpy as jnp

class MyTest(chex.TestCase):
    
    @chex.variants(with_jit=True, without_jit=True)
    def test_addition_function(self):
        def add_one(x):
            return x + 1
        
        # Get the variant-appropriate version of the function
        fn = self.variant(add_one)
        
        # Test the function
        result = fn(jnp.array([1, 2, 3]))
        expected = jnp.array([2, 3, 4])
        
        # Compare arrays element-wise; assert_equal is meant for plain Python values
        chex.assert_trees_all_equal(result, expected)
        
        # Access variant type if needed
        if self.variant.type == chex.ChexVariantType.WITH_JIT:
            # This test is running with jit
            pass

Testing All Variants

class ComprehensiveTest(chex.TestCase):
    
    @chex.all_variants
    def test_matrix_multiply(self):
        def matmul(a, b):
            return jnp.dot(a, b)
        
        fn = self.variant(matmul)
        
        a = jnp.array([[1, 2], [3, 4]])
        b = jnp.array([[5, 6], [7, 8]])
        
        result = fn(a, b)
        expected = jnp.array([[19, 22], [43, 50]])
        
        chex.assert_trees_all_equal(result, expected)

Parameterized Variant Testing

from absl.testing import parameterized

class ParameterizedVariantTest(chex.TestCase):
    
    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.parameters(
        {'batch_size': 32, 'input_dim': 784},
        {'batch_size': 64, 'input_dim': 1024},
    )
    def test_neural_network_layer(self, batch_size, input_dim):
        def linear_layer(x, weights, bias):
            return jnp.dot(x, weights) + bias
        
        fn = self.variant(linear_layer)
        
        # Create test data
        x = jnp.ones((batch_size, input_dim))
        weights = jnp.ones((input_dim, 10))
        bias = jnp.zeros(10)
        
        result = fn(x, weights, bias)
        
        # Verify output shape
        chex.assert_shape(result, (batch_size, 10))
        
        # Verify computation
        expected = jnp.full((batch_size, 10), float(input_dim))
        chex.assert_trees_all_close(result, expected)

Using Parameter Products

class ProductTest(chex.TestCase):
    
    # Generate all combinations of optimizers and learning rates.
    # With named=True, the first element of each tuple is a name fragment.
    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.named_parameters(
        *chex.params_product(
            [('sgd', 'sgd'), ('adam', 'adam'), ('rmsprop', 'rmsprop')],
            [('lr_0_1', 0.1), ('lr_0_01', 0.01), ('lr_0_001', 0.001)],
            named=True
        )
    )
    def test_optimizer_update(self, optimizer_name, learning_rate):
        def update_step(params, grads, lr):
            return params - lr * grads
        
        fn = self.variant(update_step)
        
        params = jnp.array([1.0, 2.0, 3.0])
        grads = jnp.array([0.1, 0.2, 0.3])
        
        updated_params = fn(params, grads, learning_rate)
        expected = params - learning_rate * grads
        
        chex.assert_trees_all_close(updated_params, expected)

Testing with Device Variants

class DeviceTest(chex.TestCase):
    
    @chex.variants(with_device=True, without_device=True)
    def test_device_placement(self):
        def compute_sum(x):
            return jnp.sum(x)
        
        fn = self.variant(compute_sum)
        
        x = jnp.array([1, 2, 3, 4, 5])
        result = fn(x)
        
        chex.assert_equal(result, 15)
        
        # Can check device placement if needed
        if hasattr(result, 'device'):
            # Verify device placement based on variant type
            pass

Testing with Pmap Variants

class PmapTest(chex.TestCase):
    
    @chex.variants(with_pmap=True)
    def test_parallel_computation(self):
        def parallel_square(x):
            return x ** 2
        
        fn = self.variant(parallel_square)
        
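        # Build one row per local device. By default the pmap variant broadcasts
        # each input to every device and returns the first device's output.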
        n_devices = jax.local_device_count()
        x = jnp.arange(n_devices * 4).reshape(n_devices, 4)
        
        result = fn(x)
        expected = x ** 2
        
        chex.assert_trees_all_equal(result, expected)

Advanced Variant Usage

class AdvancedVariantTest(chex.TestCase):
    
    def setUp(self):
        super().setUp()
        # Setup that runs before each variant
        self.tolerance = 1e-6
    
    @chex.all_variants
    def test_gradient_computation(self):
        def loss_fn(params, data):
            return jnp.sum((params['w'] @ data - params['b']) ** 2)
        
        # Get variant-appropriate version
        loss_fn = self.variant(loss_fn)
        
        # Create test data
        params = {'w': jnp.array([[1.0, 2.0]]), 'b': jnp.array([0.5])}
        data = jnp.array([[1.0], [2.0]])
        
        # Compute gradients
        grad_fn = jax.grad(loss_fn)
        grads = grad_fn(params, data)
        
        # Verify gradient structure matches params
        chex.assert_trees_all_equal_structs(grads, params)
        
        # Verify gradients are finite
        chex.assert_tree_all_finite(grads)
    
    @chex.variants(with_jit=True, without_jit=True)
    def test_variant_type_specific_behavior(self):
        """Test that demonstrates variant-specific testing logic."""
        
        def expensive_computation(x):
            # Some computation that might behave differently jitted/non-jitted
            return jnp.sum(jnp.sin(x) * jnp.cos(x))
        
        fn = self.variant(expensive_computation)
        x = jnp.linspace(0, 2 * jnp.pi, 1000)
        result = fn(x)
        
        # Different expectations based on variant type
        if self.variant.type == chex.ChexVariantType.WITH_JIT:
            # Jitted version might have slightly different numerical behavior
            chex.assert_rank(result, 0)
        else:
            # Non-jitted version
            chex.assert_rank(result, 0)

Key Features

Comprehensive Coverage

  • Tests same logic across multiple execution modes
  • Catches bugs that only appear in specific configurations
  • Ensures consistent behavior between jitted and non-jitted code
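As an illustration of a bug that only appears in one configuration, the sketch below uses plain JAX rather than Chex. clip_positive and clip_positive_traceable are hypothetical example functions. The first branches in Python on an array value, which works eagerly but cannot be traced by jax.jit; a test run under both the jitted and non-jitted variants would surface this.

```python
import jax
import jax.numpy as jnp


def clip_positive(x):
    # Python-level branching on an array value: fine eagerly, not traceable.
    if x > 0:
        return x
    return jnp.zeros_like(x)


def clip_positive_traceable(x):
    # Same logic expressed with jnp.where, which jax.jit can trace.
    return jnp.where(x > 0, x, jnp.zeros_like(x))


# Eager execution succeeds.
eager_result = clip_positive(jnp.float32(2.0))

# The jitted call fails while tracing the Python `if`.
try:
    jax.jit(clip_positive)(jnp.float32(2.0))
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# The traceable rewrite behaves the same with and without jit.
jit_result = jax.jit(clip_positive_traceable)(jnp.float32(-3.0))
```

Running only the non-jitted path would let clip_positive pass unnoticed, which is the gap variant testing is meant to close.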

Easy Integration

  • Drop-in replacement for standard test classes
  • Works with existing parameterized testing frameworks
  • Minimal changes to existing test code

Flexible Configuration

  • Choose specific variants or test all
  • Combine with parameterized testing
  • Support for device-specific testing

Debugging Support

  • Access to variant type within tests
  • Clear error messages when variants fail
  • Integration with Chex assertion framework

Best Practices

Use Meaningful Test Names

@chex.variants(with_jit=True, without_jit=True)
def test_neural_network_forward_pass_consistency(self):
    # Clear test purpose
    pass

Test Critical Paths

# Focus variant testing on functions that will be jitted in practice
@chex.all_variants
def test_training_step(self):
    # This will be jitted in real usage
    pass

Combine with Assertions

@chex.all_variants
def test_with_comprehensive_checks(self):
    fn = self.variant(my_function)
    result = fn(input_data)
    
    # Use Chex assertions for thorough validation
    chex.assert_shape(result, expected_shape)
    chex.assert_tree_all_finite(result)

Install with Tessl CLI

npx tessl i tessl/pypi-chex
