Comprehensive utilities library for JAX testing, debugging, and instrumentation.
JAX-compatible dataclass implementation that works seamlessly with JAX transformations and pytree operations. Chex dataclasses are automatically registered as JAX pytrees and can be used with all JAX transformations like jit, vmap, grad, etc.
JAX-compatible dataclass decorator that creates structured data containers working with JAX ecosystem.
def dataclass(
cls=None,
*,
init=True,
repr=True,
eq=True,
order=False,
unsafe_hash=False,
frozen=False,
mappable=False
):
"""
JAX-compatible dataclass decorator.
Parameters:
- cls: Class to decorate (when used without parentheses)
- init: Generate __init__ method
- repr: Generate __repr__ method
- eq: Generate __eq__ method
- order: Generate ordering methods (__lt__, __le__, __gt__, __ge__)
- unsafe_hash: Generate __hash__ method (use with caution)
- frozen: Make instances immutable
- mappable: Make dataclass compatible with collections.abc.Mapping
Returns:
- Decorated class registered as JAX pytree
"""
def mappable_dataclass(cls):
"""
Make dataclass compatible with collections.abc.Mapping interface.
Allows dataclass instances to be used with dm-tree library and provides
dict-like access patterns. Changes constructor to dict-style (no positional args).
Parameters:
- cls: A dataclass to make mappable
Returns:
- Modified dataclass implementing collections.abc.Mapping
Raises:
- ValueError: If cls is not a dataclass
"""
def register_dataclass_type_with_jax_tree_util(cls):
"""
Manually register a dataclass type with JAX tree utilities.
Normally done automatically by @chex.dataclass, but can be called
manually for dataclasses created with other decorators.
Parameters:
- cls: Dataclass type to register
"""Exception types for dataclass operations.
FrozenInstanceError = dataclasses.FrozenInstanceErrorimport chex
import jax
import jax.numpy as jnp
@chex.dataclass
class Config:
learning_rate: float
batch_size: int
hidden_dims: tuple
# Create instance
config = Config(
learning_rate=0.01,
batch_size=32,
hidden_dims=(128, 64)
)
# Works with JAX transformations
def compute_loss(config, data):
# Use config parameters in computation
return jnp.sum(data) * config.learning_rate
# Can be passed through jit, vmap, etc.
jitted_loss = jax.jit(compute_loss)
result = jitted_loss(config, jnp.array([1.0, 2.0, 3.0]))

@chex.dataclass(frozen=True)
class ImmutableConfig:
model_name: str
version: int
config = ImmutableConfig(model_name="transformer", version=1)
# This would raise FrozenInstanceError
# config.version = 2 # Error!
# Use replace() to create modified copies
new_config = config.replace(version=2)

@chex.mappable_dataclass
@chex.dataclass
class Parameters:
weights: jnp.ndarray
bias: jnp.ndarray
scale: float = 1.0
# Can be created dict-style (no positional args)
params = Parameters({
'weights': jnp.ones((10, 5)),
'bias': jnp.zeros(5),
'scale': 0.5
})
# Supports dict-like operations
print(params['weights'].shape) # (10, 5)
print(list(params.keys())) # ['weights', 'bias', 'scale']
print(len(params)) # 3
# Works with dm-tree
import tree
flat_params = tree.flatten(params)

@chex.dataclass
class LayerConfig:
input_dim: int
output_dim: int
activation: str
@chex.dataclass
class ModelConfig:
encoder: LayerConfig
decoder: LayerConfig
dropout_rate: float
# Create nested structure
model_config = ModelConfig(
encoder=LayerConfig(input_dim=784, output_dim=128, activation='relu'),
decoder=LayerConfig(input_dim=128, output_dim=10, activation='softmax'),
dropout_rate=0.1
)
# Works seamlessly with JAX transformations
def init_model(config, key):
# Initialize model parameters based on config
encoder_key, decoder_key = jax.random.split(key)
encoder_weights = jax.random.normal(
encoder_key, (config.encoder.input_dim, config.encoder.output_dim)
)
decoder_weights = jax.random.normal(
decoder_key, (config.decoder.input_dim, config.decoder.output_dim)
)
return {
'encoder': encoder_weights,
'decoder': decoder_weights,
'config': config
}
# Can vectorize over configs
init_fn = jax.vmap(init_model, in_axes=(None, 0))
keys = jax.random.split(jax.random.PRNGKey(42), 5)
models = init_fn(model_config, keys)

@chex.dataclass
class TrainingState:
params: dict
optimizer_state: dict
step: int
rng_key: jnp.ndarray
def update_step(state, batch):
# Training step that updates the entire state
new_params = update_params(state.params, batch)
new_opt_state = update_optimizer(state.optimizer_state, batch)
new_key, _ = jax.random.split(state.rng_key)
return state.replace(
params=new_params,
optimizer_state=new_opt_state,
step=state.step + 1,
rng_key=new_key
)
# Works with jit compilation
jitted_update = jax.jit(update_step)
# Works with scan for training loops
def train_loop(state, batches):
final_state, _ = jax.lax.scan(
lambda s, batch: (jitted_update(s, batch), None),
state,
batches
)
return final_state

import dataclasses
# Create dataclass with standard library
@dataclasses.dataclass
class StandardConfig:
value: float
# Manually register for JAX compatibility
chex.register_dataclass_type_with_jax_tree_util(StandardConfig)
# Now works with JAX transformations
config = StandardConfig(value=1.0)
jax.tree_map(lambda x: x * 2, config)  # StandardConfig(value=2.0)

All chex dataclasses are automatically registered as JAX pytrees, enabling:
- jax.tree_map, jax.tree_flatten, and related tree utilities

Dataclasses support all standard field operations:
- replace() method for creating modified copies

Frozen dataclasses provide:
- Immutability: assigning to a field raises FrozenInstanceError
Mappable dataclasses provide:
- Dict-style item access (instance['key'])
- Mapping methods (keys(), values(), items())

@chex.dataclass
class Config:
learning_rate: float # Clear type information
layers: List[int] # Supports complex types
activation: str = 'relu'  # Default values

@chex.dataclass(frozen=True)
class Hyperparameters:
lr: float
batch_size: int
# Immutable configuration

@chex.mappable_dataclass
@chex.dataclass
class Parameters:
weights: jnp.ndarray
bias: jnp.ndarray
# Enables params['weights'] accessInstall with Tessl CLI
npx tessl i tessl/pypi-chexevals
scenario-1
scenario-2
scenario-3
scenario-4
scenario-5
scenario-6
scenario-7
scenario-8
scenario-9
scenario-10