tessl/pypi-chex

Comprehensive utilities library for JAX testing, debugging, and instrumentation


JAX-Compatible Dataclasses

Chex provides a dataclass implementation designed for JAX. Chex dataclasses are registered as JAX pytrees automatically, so they can be passed through transformations such as jit, vmap, and grad, provided their leaves are valid JAX types (arrays and numeric scalars).

Capabilities

Core Dataclass Functionality

A JAX-compatible dataclass decorator for creating structured data containers that work across the JAX ecosystem.

def dataclass(
    cls=None,
    *,
    init=True,
    repr=True,
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    mappable_dataclass=True
):
    """
    JAX-compatible dataclass decorator.
    
    Parameters:
    - cls: Class to decorate (when used without parentheses)
    - init: Generate __init__ method
    - repr: Generate __repr__ method
    - eq: Generate __eq__ method
    - order: Generate ordering methods (__lt__, __le__, __gt__, __ge__)
    - unsafe_hash: Generate __hash__ method (use with caution)
    - frozen: Make instances immutable
    - mappable_dataclass: Make the dataclass implement collections.abc.Mapping
      (enabled by default; the constructor then accepts keyword arguments only,
      and mappable_dataclass=False restores positional arguments)
    
    Returns:
    - Decorated class registered as JAX pytree
    """

def mappable_dataclass(cls):
    """
    Make dataclass compatible with collections.abc.Mapping interface.
    
    Allows dataclass instances to be used with dm-tree library and provides
    dict-like access patterns. Changes constructor to dict-style (no positional args).
    
    Parameters:
    - cls: A dataclass to make mappable
    
    Returns:
    - Modified dataclass implementing collections.abc.Mapping
    
    Raises:
    - ValueError: If cls is not a dataclass
    """

def register_dataclass_type_with_jax_tree_util(cls):
    """
    Manually register a dataclass type with JAX tree utilities.
    
    Normally done automatically by @chex.dataclass, but can be called
    manually for dataclasses created with other decorators.
    
    Parameters:
    - cls: Dataclass type to register
    """

Dataclass Exceptions

Exception types for dataclass operations.

FrozenInstanceError = dataclasses.FrozenInstanceError

Usage Examples

Basic Dataclass Usage

import chex
import jax
import jax.numpy as jnp

@chex.dataclass
class Config:
    learning_rate: float
    batch_size: int
    hidden_dims: tuple
    
# Create instance
config = Config(
    learning_rate=0.01,
    batch_size=32,
    hidden_dims=(128, 64)
)

# Works with JAX transformations
def compute_loss(config, data):
    # Use config parameters in computation
    return jnp.sum(data) * config.learning_rate

# Can be passed through jit, vmap, etc.
jitted_loss = jax.jit(compute_loss)
result = jitted_loss(config, jnp.array([1.0, 2.0, 3.0]))

Frozen Dataclasses

@chex.dataclass(frozen=True)
class ImmutableConfig:
    model_name: str
    version: int
    
config = ImmutableConfig(model_name="transformer", version=1)

# This would raise FrozenInstanceError
# config.version = 2  # Error!

# Use replace() to create modified copies
new_config = config.replace(version=2)

Mappable Dataclasses

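# @chex.dataclass implements collections.abc.Mapping by default
# (mappable_dataclass=True), so @chex.mappable_dataclass is not stacked on top of it
@chex.dataclass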
class Parameters:
    weights: jnp.ndarray
    bias: jnp.ndarray
    scale: float = 1.0

# Can be created dict-style (no positional args)
params = Parameters({
    'weights': jnp.ones((10, 5)),
    'bias': jnp.zeros(5),
    'scale': 0.5
})

# Supports dict-like operations
print(params['weights'].shape)  # (10, 5)
print(list(params.keys()))      # ['weights', 'bias', 'scale']
print(len(params))              # 3

# Works with dm-tree
import tree
flat_params = tree.flatten(params)

Nested Dataclasses

@chex.dataclass
class LayerConfig:
    input_dim: int
    output_dim: int
    activation: str

@chex.dataclass  
class ModelConfig:
    encoder: LayerConfig
    decoder: LayerConfig
    dropout_rate: float

# Create nested structure
model_config = ModelConfig(
    encoder=LayerConfig(input_dim=784, output_dim=128, activation='relu'),
    decoder=LayerConfig(input_dim=128, output_dim=10, activation='softmax'),
    dropout_rate=0.1
)

# String fields such as `activation` are not valid JAX types, so the config
# is closed over rather than passed into or returned from a transformed function
def init_model(config, key):
    # Initialize model parameters based on config
    encoder_key, decoder_key = jax.random.split(key)
    
    encoder_weights = jax.random.normal(
        encoder_key, (config.encoder.input_dim, config.encoder.output_dim)
    )
    decoder_weights = jax.random.normal(
        decoder_key, (config.decoder.input_dim, config.decoder.output_dim)
    )
    
    return {
        'encoder': encoder_weights,
        'decoder': decoder_weights
    }

# Vectorize over PRNG keys to initialize five models with the same config
init_fn = jax.vmap(lambda key: init_model(model_config, key))
keys = jax.random.split(jax.random.PRNGKey(42), 5)
models = init_fn(keys)

Integration with JAX Transformations

@chex.dataclass
class TrainingState:
    params: dict
    optimizer_state: dict
    step: int
    rng_key: jnp.ndarray

def update_step(state, batch):
    # Training step that updates the entire state.
    # update_params and update_optimizer are placeholders for user-defined functions.
    new_params = update_params(state.params, batch)
    new_opt_state = update_optimizer(state.optimizer_state, batch)
    new_key, _ = jax.random.split(state.rng_key)
    
    return state.replace(
        params=new_params,
        optimizer_state=new_opt_state,
        step=state.step + 1,
        rng_key=new_key
    )

# Works with jit compilation
jitted_update = jax.jit(update_step)

# Works with scan for training loops
def train_loop(state, batches):
    final_state, _ = jax.lax.scan(
        lambda s, batch: (jitted_update(s, batch), None),
        state,
        batches
    )
    return final_state

Manual Registration

import dataclasses

# Create dataclass with standard library
@dataclasses.dataclass
class StandardConfig:
    value: float

# Manually register for JAX compatibility
chex.register_dataclass_type_with_jax_tree_util(StandardConfig)

# Now works with JAX transformations
config = StandardConfig(value=1.0)
jax.tree_util.tree_map(lambda x: x * 2, config)  # StandardConfig(value=2.0)

Key Features

Automatic PyTree Registration

All chex dataclasses are automatically registered as JAX pytrees, enabling:

  • Seamless integration with jax.tree_util.tree_map, jax.tree_util.tree_flatten, etc.
  • Support for all JAX transformations (jit, vmap, grad, scan, etc.)
  • Compatibility with gradient computation and optimization libraries

Field Operations

Dataclasses support all standard field operations:

  • replace() method for creating modified copies
  • Field access and introspection
  • Default values and factory functions
  • Type hints (annotations document intent but are not checked at runtime)

Immutability Support

Frozen dataclasses provide:

  • Immutable instances that can't be modified after creation
  • Safe sharing across transformations
  • Clear semantics for functional programming patterns

Mapping Interface

Mappable dataclasses provide:

  • Dict-style access patterns (instance['key'])
  • Compatibility with dm-tree library
  • Integration with dictionary-based workflows
  • Iterator support (keys(), values(), items())

Best Practices

Use Type Hints

from typing import List

@chex.dataclass
class Config:
    learning_rate: float  # Clear type information
    layers: List[int]     # Supports complex types
    activation: str = 'relu'  # Default values

Prefer Frozen for Immutable Data

@chex.dataclass(frozen=True)
class Hyperparameters:
    lr: float
    batch_size: int
    # Immutable configuration

Use Mappable for Dict-like Access

# @chex.dataclass is mappable by default (mappable_dataclass=True)
@chex.dataclass
class Parameters:
    weights: jnp.ndarray
    bias: jnp.ndarray
    # Enables params['weights'] access

Install with Tessl CLI

npx tessl i tessl/pypi-chex
