Comprehensive utilities library for JAX testing, debugging, and instrumentation
A utility for validating neural network parameters represented as PyTree structures. The validator checks that every value in a parameter tree is finite (no NaN or infinity) and compares parameter trees, for example from two training checkpoints, for exact equality in structure and values.
@generates
def validate_parameters_finite(params):
    """
    Validates that all numerical values in a parameter PyTree are finite.

    Args:
        params: A PyTree (nested dict/list/tuple) containing JAX arrays

    Raises:
        AssertionError: If any non-finite values (NaN or infinity) are found
    """
    pass
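The following is one possible sketch of `validate_parameters_finite`, not a reference solution for this spec. It assumes a JAX release recent enough to provide `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`. Reporting the key path of each offending leaf is an added convenience that the spec does not require. The check calls `bool()` on array values, so it needs concrete arrays and will not work on traced values inside `jax.jit`. chex also ships `chex.assert_tree_all_finite` for the same purpose.

```python
import jax
import jax.numpy as jnp


def validate_parameters_finite(params):
    """Raise AssertionError if any leaf of `params` contains NaN or +/-inf."""
    # Flatten the PyTree into (key path, leaf) pairs so the error can name leaves.
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(params)
    bad_paths = []
    for path, leaf in leaves_with_paths:
        # jnp.isfinite is False for NaN, +inf and -inf (always True for integers).
        if not bool(jnp.all(jnp.isfinite(jnp.asarray(leaf)))):
            bad_paths.append(jax.tree_util.keystr(path))
    if bad_paths:
        raise AssertionError(f"Non-finite values found at: {', '.join(bad_paths)}")


# Usage: a clean tree passes silently; a NaN in "w" raises AssertionError.
params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}
validate_parameters_finite(params)
```

Collecting every bad path before raising, instead of stopping at the first one, makes it easier to see at a glance whether a divergence affects a single layer or the whole model.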
def compare_parameter_trees(params1, params2):
    """
    Compares two parameter PyTrees for exact equality in structure and values.

    Args:
        params1: First PyTree containing JAX arrays
        params2: Second PyTree containing JAX arrays

    Raises:
        AssertionError: If the trees differ in structure or values
    """
    pass

Provides JAX testing and assertion utilities including tree validation functions.
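One possible sketch of `compare_parameter_trees`, assuming JAX and NumPy are installed. It first compares tree structures with `jax.tree_util.tree_structure`, then compares leaves pairwise with `np.array_equal`. This is an illustration, not the reference solution; chex provides `chex.assert_trees_all_equal` for a similar check.

```python
import jax
import numpy as np


def compare_parameter_trees(params1, params2):
    """Raise AssertionError unless both PyTrees match in structure and values."""
    # Structure covers container types, dict keys, nesting and leaf count.
    struct1 = jax.tree_util.tree_structure(params1)
    struct2 = jax.tree_util.tree_structure(params2)
    if struct1 != struct2:
        raise AssertionError(f"Tree structures differ: {struct1} vs {struct2}")
    # With identical structure, leaves line up one-to-one in flattening order.
    leaves1 = jax.tree_util.tree_leaves(params1)
    leaves2 = jax.tree_util.tree_leaves(params2)
    for i, (a, b) in enumerate(zip(leaves1, leaves2)):
        a, b = np.asarray(a), np.asarray(b)
        if a.shape != b.shape or not np.array_equal(a, b):
            raise AssertionError(
                f"Leaf {i} differs (shapes {a.shape} vs {b.shape})"
            )
```

Two behaviors of this sketch are worth knowing. `np.array_equal` treats NaN as unequal to itself by default, so trees containing NaN never compare equal. It also compares values only, so a `float32` leaf and a `float64` leaf holding the same numbers pass; add an explicit `a.dtype != b.dtype` check if dtypes must match too.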
Install with Tessl CLI
npx tessl i tessl/pypi-chex