Comprehensive utilities library for JAX testing, debugging, and instrumentation
A utility for validating that arrays in distributed computations are properly sharded across devices.
@generates
def validate_sharding(tree_of_arrays):
    """
    Validates that all arrays in a PyTree structure are properly sharded across devices.

    Args:
        tree_of_arrays: A JAX array or a PyTree (nested structure) containing JAX arrays.

    Returns:
        bool: True if all arrays are properly sharded.

    Raises:
        AssertionError: If any array in the structure is not properly sharded.
    """
    pass

Provides JAX testing utilities including array sharding validation.
@satisfied-by
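The stub above leaves the body as `pass`. Below is a minimal sketch of one possible body. It uses only core JAX APIs (`jax.tree_util.tree_leaves`, `jax.Array`, `Sharding.device_set`) rather than chex itself. It also assumes a definition the spec does not give: a leaf counts as "properly sharded" when it is a `jax.Array` whose shards cover every device returned by `jax.devices()`, so a fully replicated array also passes. The mesh axis name `"data"` and the `params` tree in the usage part are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def validate_sharding(tree_of_arrays):
    """Sketch only.

    ASSUMPTION: "properly sharded" means every leaf is a jax.Array whose
    shards cover all devices in jax.devices(); the spec does not define it.
    """
    expected_devices = set(jax.devices())
    # tree_leaves flattens dicts, lists, tuples, etc.; a bare array yields [array].
    leaves = jax.tree_util.tree_leaves(tree_of_arrays)
    for index, leaf in enumerate(leaves):
        # NumPy arrays and Python scalars carry no device sharding at all.
        if not isinstance(leaf, jax.Array):
            raise AssertionError(
                f"leaf {index} is {type(leaf).__name__}, not a jax.Array"
            )
        # device_set is the set of devices that hold a shard of this array.
        actual_devices = leaf.sharding.device_set
        if actual_devices != expected_devices:
            raise AssertionError(
                f"leaf {index} spans {len(actual_devices)} of "
                f"{len(expected_devices)} devices"
            )
    return True


# Usage: build a 1-D mesh over all local devices and split the leading axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))
params = {
    # Leading dimensions are multiples of the device count so they split evenly.
    "w": jax.device_put(jnp.arange(2 * devices.size, dtype=jnp.float32), sharding),
    "b": jax.device_put(jnp.zeros(devices.size), sharding),
}
print(validate_sharding(params))  # True
```

Raising `AssertionError` explicitly, instead of using `assert` statements, keeps the check active under `python -O`. Treating replication as acceptable is a design choice. A stricter check could also inspect `leaf.sharding.is_fully_replicated`.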
Install with Tessl CLI
npx tessl i tessl/pypi-chex
evals
scenario-1
scenario-2
scenario-3
scenario-4
scenario-5
scenario-6
scenario-7
scenario-8
scenario-9
scenario-10