Comprehensive utilities library for JAX testing, debugging, and instrumentation
Build a testing utility that enables multi-device testing of parallel JAX computations in a local development environment.
Your implementation should provide a test framework that:
Device Environment Setup: Create a function setup_multi_device_test(n_devices) that configures the testing environment to simulate multiple computational devices using CPU threads. This allows testing parallel code without requiring actual multi-device hardware.
Parallel Computation Testing: Implement a test function test_parallel_sum(arrays) that sums each input array on a separate device and returns the per-array sums as a list.
Device Configuration Validation: Create a function get_device_count() that returns the number of available devices in the current JAX environment.
Cleanup: Implement a teardown_multi_device_test() function to reset the device environment after testing.
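The setup and cleanup steps above can be sketched with XLA's `--xla_force_host_platform_device_count` flag, which splits the host CPU into several logical devices. This is a minimal sketch rather than a reference solution: the module-level variables `_original_xla_flags` and `_is_set_up` are hypothetical bookkeeping, and it assumes setup runs before JAX initializes its backend, since XLA reads the flag only once per process.

```python
import os
from typing import Optional

_DEVICE_FLAG = "--xla_force_host_platform_device_count"

# Hypothetical bookkeeping (not part of the spec): remembers the XLA_FLAGS
# value that was present before the first setup call.
_original_xla_flags: Optional[str] = None
_is_set_up = False


def setup_multi_device_test(n_devices: int) -> None:
    """Request n_devices simulated CPU devices via XLA_FLAGS."""
    global _original_xla_flags, _is_set_up
    if n_devices < 1:
        raise ValueError("n_devices must be a positive integer")
    if not _is_set_up:
        # Save the original value only once, so repeated setups do not lose it.
        _original_xla_flags = os.environ.get("XLA_FLAGS")
        _is_set_up = True
    # Drop any earlier device-count flag, keep unrelated flags, append ours.
    kept = [
        flag
        for flag in os.environ.get("XLA_FLAGS", "").split()
        if not flag.startswith(_DEVICE_FLAG)
    ]
    os.environ["XLA_FLAGS"] = " ".join(kept + [f"{_DEVICE_FLAG}={n_devices}"])


def teardown_multi_device_test() -> None:
    """Restore XLA_FLAGS to the value it had before setup."""
    global _original_xla_flags, _is_set_up
    if not _is_set_up:
        return
    if _original_xla_flags is None:
        os.environ.pop("XLA_FLAGS", None)
    else:
        os.environ["XLA_FLAGS"] = _original_xla_flags
    _original_xla_flags = None
    _is_set_up = False
```

Because the flag is read at backend initialization, this teardown restores the environment variable for later processes; it does not shrink the device list of a backend that is already running in the current process.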
When setup_multi_device_test(4) is called, get_device_count() returns 4 @test
Given [jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], test_parallel_sum returns [6, 15] when using 2 devices @test
When teardown_multi_device_test() is called following setup_multi_device_test(8), the device count returns to the original system default @test
@generates
def setup_multi_device_test(n_devices: int) -> None:
    """Configure the testing environment to use n_devices simulated CPU devices."""
    pass

def get_device_count() -> int:
    """Return the number of available JAX devices."""
    pass

def test_parallel_sum(arrays: list) -> list:
    """Test parallel computation by summing arrays across devices."""
    pass

def teardown_multi_device_test() -> None:
    """Reset device environment to original state."""
    pass

Provides multi-device testing utilities for JAX code, including CPU device simulation capabilities.
@satisfied-by
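One possible shape for the device count and parallel sum pieces of the API is sketched below with `jax.pmap`, which maps a function over the leading axis of its input with one slice per device. The name `parallel_sum` is a hypothetical stand-in for the spec's `test_parallel_sum`, chosen so pytest does not collect it as a test case. The sketch assumes all input arrays share one shape and that JAX has not been initialized earlier in the process. chex's `set_n_cpu_devices` helper is an alternative way to request simulated CPU devices.

```python
import os

# Assumption: nothing has initialized JAX yet in this process, so the flag
# below still takes effect. Two simulated CPU devices match the spec's example.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
).strip()
os.environ.setdefault("JAX_PLATFORMS", "cpu")  # keep the sketch on the CPU backend

import jax
import jax.numpy as jnp
import numpy as np


def get_device_count() -> int:
    """Return the number of devices visible to this process."""
    return jax.local_device_count()


def parallel_sum(arrays: list) -> list:
    """Sum each array on its own device and return the sums as Python numbers."""
    if len(arrays) > get_device_count():
        raise ValueError("need at least one device per input array")
    # Stack to shape (n_arrays, length); pmap assigns one row to each device.
    stacked = jnp.stack(arrays)
    per_device_sums = jax.pmap(jnp.sum)(stacked)
    # Gather the per-device results back to the host as a plain list.
    return np.asarray(per_device_sums).tolist()


print(parallel_sum([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])]))  # [6, 15]
```

Stacking requires equally shaped inputs; ragged inputs would need padding or a per-array `jax.device_put` loop instead.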
Install with Tessl CLI
npx tessl i tessl/pypi-chexevals