Comprehensive utilities library for JAX testing, debugging, and instrumentation
A configuration management system for neural network hyperparameters that integrates with JAX transformations.
Create a dataclass that stores neural network hyperparameters and can be used seamlessly with JAX tree operations.
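The requirement above can be illustrated with a minimal sketch. It uses a plain `dataclasses.dataclass` registered by hand with `jax.tree_util.register_pytree_node`, not the library's own decorator, so it roughly shows what a JAX-friendly dataclass decorator automates. The helpers `_flatten` and `_unflatten` are illustrative names, not part of any library.

```python
import dataclasses

import jax


@dataclasses.dataclass
class NetworkConfig:
    """Neural network hyperparameters (plain dataclass, registered by hand below)."""

    hidden_dim: int
    num_layers: int
    dropout_rate: float
    learning_rate: float


def _flatten(cfg):
    # Children are the field values in declaration order; there is no static aux data.
    children = tuple(getattr(cfg, f.name) for f in dataclasses.fields(cfg))
    return children, None


def _unflatten(aux_data, children):
    # Rebuild the dataclass positionally from the (possibly transformed) children.
    del aux_data  # unused: every field is a child
    return NetworkConfig(*children)


# Tell JAX how to take a NetworkConfig apart and put it back together.
jax.tree_util.register_pytree_node(NetworkConfig, _flatten, _unflatten)

config = NetworkConfig(hidden_dim=128, num_layers=3, dropout_rate=0.1, learning_rate=0.001)
leaves, treedef = jax.tree_util.tree_flatten(config)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

Every field is a leaf here. `tree_flatten` therefore returns the four values in declaration order, and `tree_unflatten` rebuilds an equal `NetworkConfig` from them.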
NetworkConfig dataclass with fields hidden_dim (int), num_layers (int), dropout_rate (float), and learning_rate (float) can be created and is compatible with JAX tree utilities @test
NetworkConfig dataclass can be flattened and unflattened using jax.tree_util.tree_flatten and jax.tree_util.tree_unflatten @test
Apply transformations to all numeric values in the configuration using JAX tree mapping.
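The tree-mapping requirement can be sketched the same way, again assuming a hand-registered dataclass rather than the library's decorator. `jax.tree_util.tree_map` applies the function to every leaf and rebuilds the same structure. Each field is therefore doubled, and the result is still a `NetworkConfig`.

```python
import dataclasses

import jax


@dataclasses.dataclass
class NetworkConfig:
    hidden_dim: int
    num_layers: int
    dropout_rate: float
    learning_rate: float


# Register the dataclass so tree utilities recurse into its fields (hand-rolled sketch).
jax.tree_util.register_pytree_node(
    NetworkConfig,
    lambda cfg: (tuple(getattr(cfg, f.name) for f in dataclasses.fields(cfg)), None),
    lambda _aux, children: NetworkConfig(*children),
)

config = NetworkConfig(hidden_dim=128, num_layers=3, dropout_rate=0.1, learning_rate=0.001)

# tree_map applies the function to every leaf and rebuilds a NetworkConfig.
doubled = jax.tree_util.tree_map(lambda x: x * 2, config)
print(doubled)  # NetworkConfig(hidden_dim=256, num_layers=6, dropout_rate=0.2, learning_rate=0.002)
```

Because every field is a leaf, integer fields such as `hidden_dim` are doubled too, which is what this spec expects. A configuration used under `jax.jit` might instead return structural fields like `hidden_dim` and `num_layers` as auxiliary data from the flatten function. Tree transformations and tracing would then leave those fields untouched.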
NetworkConfig with values hidden_dim=128, num_layers=3, dropout_rate=0.1, learning_rate=0.001: mapping lambda x: x * 2 over it with jax.tree_util.tree_map doubles all numeric values @test
Enable dictionary-like access patterns for configuration fields.
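One way to get dictionary-style access is to have the dataclass implement `collections.abc.Mapping`. Once `__getitem__`, `__iter__`, and `__len__` are defined, that base class supplies `.keys()`, `.values()`, `.items()`, `.get()`, and `in`. The sketch below is plain Python with a hypothetical `_field_names` helper. It is not the library's implementation, and it does not register the class as a JAX pytree.

```python
import collections.abc
import dataclasses


@dataclasses.dataclass
class MappableNetworkConfig(collections.abc.Mapping):
    hidden_dim: int
    num_layers: int
    dropout_rate: float
    learning_rate: float

    def _field_names(self):
        # Hypothetical helper: declared field names, in declaration order.
        return tuple(f.name for f in dataclasses.fields(self))

    def __getitem__(self, key):
        # Only declared fields are valid keys; anything else raises KeyError like a dict.
        if key not in self._field_names():
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self):
        # Iterating yields field names, just as iterating a dict yields its keys.
        return iter(self._field_names())

    def __len__(self):
        return len(self._field_names())


config = MappableNetworkConfig(hidden_dim=128, num_layers=3, dropout_rate=0.1, learning_rate=0.001)
print(config["hidden_dim"])  # 128
print(list(config.keys()))  # ['hidden_dim', 'num_layers', 'dropout_rate', 'learning_rate']
print("learning_rate" in config)  # True
```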
MappableNetworkConfig dataclass can be accessed using dictionary syntax (e.g., config['hidden_dim']) to retrieve field values @test
MappableNetworkConfig supports iteration over its keys and the standard dictionary methods .keys(), .values(), and .items() @test
Reconstruct a configuration dataclass from a tuple representation.
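A minimal sketch of a `from_tuple` classmethod, with a matching `to_tuple`, on a plain dataclass. It assumes that tuple position corresponds to field declaration order. Both methods are written by hand here. Whether a given dataclass decorator provides them, and with what signature, should be checked against its documentation.

```python
import dataclasses


@dataclasses.dataclass
class NetworkConfig:
    hidden_dim: int
    num_layers: int
    dropout_rate: float
    learning_rate: float

    @classmethod
    def from_tuple(cls, values):
        # Pair each value with the field declared at the same position.
        values = tuple(values)
        names = [f.name for f in dataclasses.fields(cls)]
        if len(values) != len(names):
            raise ValueError(f"expected {len(names)} values, got {len(values)}")
        # Keyword construction still works if the constructor rejects positional args.
        return cls(**dict(zip(names, values)))

    def to_tuple(self):
        # Inverse of from_tuple: field values in declaration order.
        return tuple(getattr(self, f.name) for f in dataclasses.fields(self))


config = NetworkConfig.from_tuple((128, 3, 0.1, 0.001))
print(config.hidden_dim, config.learning_rate)  # 128 0.001
```

Passing values as keywords instead of `cls(*values)` keeps the method correct even when a decorator makes `__init__` keyword-only. The explicit length check turns a wrong-sized tuple into a clear error rather than a confusing `TypeError`.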
A NetworkConfig instance can be created from the tuple (128, 3, 0.1, 0.001) using the from_tuple method, with values assigned to fields in declaration order @test
@generates
"""Neural network configuration management."""
# Define your NetworkConfig dataclass here
# Define your MappableNetworkConfig dataclass here

Provides JAX-friendly dataclass decorators and testing utilities.
Install with Tessl CLI
npx tessl i tessl/pypi-chex