
tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation


Multi-Device Parallel Testing Framework

Build a testing utility that enables multi-device testing of parallel JAX computations in a local development environment.

Requirements

Your implementation should provide a test framework that:

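  1. Device Environment Setup: Create a function setup_multi_device_test(n_devices) that configures the testing environment to simulate multiple computational devices using CPU threads. This allows testing parallel code without requiring actual multi-device hardware. Because JAX fixes its device count when the backend is first initialized, this function must run before any JAX computation in the process.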

  2. Parallel Computation Testing: Implement a test function test_parallel_sum(arrays) that:

    • Takes a list of JAX arrays as input
    • Uses parallel mapping (for example, jax.pmap) to compute the sum of each array, with each array placed on its own simulated device
    • Returns a list of computed sums
    • The test should verify that parallel execution works correctly across the simulated devices
  3. Device Configuration Validation: Create a function get_device_count() that returns the number of available devices in the current JAX environment.

  4. Cleanup: Implement a teardown_multi_device_test() function to reset the device environment after testing.

Test Cases

  • Given setup_multi_device_test(4) is called, get_device_count() returns 4 @test
  • Given arrays [jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], test_parallel_sum returns [6, 15] when using 2 devices @test
  • After teardown_multi_device_test() is called following setup_multi_device_test(8), the device count returns to the original system default @test
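Since the forced device count is read only once, when a process first initializes its JAX backend, one way to check these test cases is to probe each configuration in a fresh Python interpreter. The sketch below assumes jax is installed and tests only the CPU backend (via the `JAX_PLATFORMS` environment variable). It sets `--xla_force_host_platform_device_count`, the XLA flag that `chex.set_n_cpu_devices` writes into `XLA_FLAGS`, directly in the child's environment. The helper `device_count_in_fresh_process` is hypothetical, not part of chex or the spec's API.

```python
# Sketch, assuming jax is installed: each configuration is probed in a fresh
# interpreter, because the forced device count is read once per process.
import os
import subprocess
import sys
from typing import Optional

_PROBE = "import jax; print(jax.device_count())"


def device_count_in_fresh_process(n_devices: Optional[int] = None) -> int:
    """Hypothetical helper: jax.device_count() as seen by a new Python process."""
    env = dict(os.environ)
    env["JAX_PLATFORMS"] = "cpu"  # assumption: only the CPU backend is tested
    env.pop("XLA_FLAGS", None)    # no forced count means the system default
    if n_devices is not None:
        # The same XLA flag that chex.set_n_cpu_devices writes into XLA_FLAGS.
        env["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={n_devices}"
    result = subprocess.run(
        [sys.executable, "-c", _PROBE],
        env=env, capture_output=True, text=True, check=True,
    )
    # JAX may log warnings to stderr; the count is the last line of stdout.
    return int(result.stdout.strip().splitlines()[-1])


default_count = device_count_in_fresh_process()
assert device_count_in_fresh_process(4) == 4  # test case 1
# Test case 3: once the flag is dropped, a new process is back to the default.
assert device_count_in_fresh_process(8) == 8
assert device_count_in_fresh_process() == default_count
```

Running each case in its own process keeps the cases independent, at the cost of paying JAX's import and startup time once per probe.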

Implementation

@generates

API

def setup_multi_device_test(n_devices: int) -> None:
    """Configure the testing environment to use n_devices simulated CPU devices."""
    pass

def get_device_count() -> int:
    """Return the number of available JAX devices."""
    pass

def test_parallel_sum(arrays: list) -> list:
    """Test parallel computation by summing arrays across devices."""
    pass

def teardown_multi_device_test() -> None:
    """Reset device environment to original state."""
    pass

Dependencies { .dependencies }

chex { .dependency }

Provides multi-device testing utilities for JAX code, including CPU device simulation capabilities.

@satisfied-by

Install with Tessl CLI

npx tessl i tessl/pypi-chex
