# JAX-Compatible Dataclasses

Chex provides a JAX-compatible dataclass implementation. Chex dataclasses are automatically registered as JAX pytrees, so instances can be passed through JAX transformations such as `jit`, `vmap`, and `grad` and handled by pytree utilities.

## Capabilities

### Core Dataclass Functionality

The `chex.dataclass` decorator creates structured data containers that are registered as pytrees and work across the JAX ecosystem.

```python { .api }
def dataclass(
    cls=None,
    *,
    init=True,
    repr=True,
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    mappable_dataclass=True
):
    """
    JAX-compatible dataclass decorator.

    Parameters:
    - cls: Class to decorate (when used without parentheses)
    - init: Generate __init__ method
    - repr: Generate __repr__ method
    - eq: Generate __eq__ method
    - order: Generate ordering methods (__lt__, __le__, __gt__, __ge__)
    - unsafe_hash: Generate __hash__ method (use with caution)
    - frozen: Make instances immutable
    - mappable_dataclass: Make the dataclass implement collections.abc.Mapping (enabled by default)

    Returns:
    - Decorated class registered as JAX pytree
    """

def mappable_dataclass(cls):
    """
    Make dataclass compatible with collections.abc.Mapping interface.

    Allows dataclass instances to be used with dm-tree library and provides
    dict-like access patterns. Changes constructor to dict-style (no positional args).

    Parameters:
    - cls: A dataclass to make mappable

    Returns:
    - Modified dataclass implementing collections.abc.Mapping

    Raises:
    - ValueError: If cls is not a dataclass
    """

def register_dataclass_type_with_jax_tree_util(cls):
    """
    Manually register a dataclass type with JAX tree utilities.

    Normally done automatically by @chex.dataclass, but can be called
    manually for dataclasses created with other decorators.

    Parameters:
    - cls: Dataclass type to register
    """
```

### Dataclass Exceptions

Assigning to a field of a frozen dataclass raises `FrozenInstanceError`, which is an alias of the standard-library exception.

```python { .api }
FrozenInstanceError = dataclasses.FrozenInstanceError
```

## Usage Examples

### Basic Dataclass Usage

```python
import chex
import jax
import jax.numpy as jnp

@chex.dataclass
class Config:
    learning_rate: float
    batch_size: int
    hidden_dims: tuple

# Create instance
config = Config(
    learning_rate=0.01,
    batch_size=32,
    hidden_dims=(128, 64)
)

# Works with JAX transformations
def compute_loss(config, data):
    # Use config parameters in computation
    return jnp.sum(data) * config.learning_rate

# Can be passed through jit, vmap, etc.
jitted_loss = jax.jit(compute_loss)
result = jitted_loss(config, jnp.array([1.0, 2.0, 3.0]))
```

### Frozen Dataclasses

```python
import chex

@chex.dataclass(frozen=True)
class ImmutableConfig:
    model_name: str
    version: int

config = ImmutableConfig(model_name="transformer", version=1)

# This would raise FrozenInstanceError
# config.version = 2  # Error!

# Use replace() to create modified copies
new_config = config.replace(version=2)
```

### Mappable Dataclasses

```python
import chex
import jax.numpy as jnp

# chex.dataclass implements collections.abc.Mapping by default
# (mappable_dataclass=True), so no extra decorator is needed.
@chex.dataclass
class Parameters:
    weights: jnp.ndarray
    bias: jnp.ndarray
    scale: float = 1.0

# Constructed dict-style with keyword arguments
params = Parameters(
    weights=jnp.ones((10, 5)),
    bias=jnp.zeros(5),
    scale=0.5
)

# Supports dict-like operations
print(params['weights'].shape)  # (10, 5)
print(list(params.keys()))      # ['weights', 'bias', 'scale']
print(len(params))              # 3

# Works with dm-tree
import tree
flat_params = tree.flatten(params)
```

### Nested Dataclasses

```python
import chex
import jax

@chex.dataclass
class LayerConfig:
    input_dim: int
    output_dim: int
    activation: str

@chex.dataclass
class ModelConfig:
    encoder: LayerConfig
    decoder: LayerConfig
    dropout_rate: float

# Create nested structure
model_config = ModelConfig(
    encoder=LayerConfig(input_dim=784, output_dim=128, activation='relu'),
    decoder=LayerConfig(input_dim=128, output_dim=10, activation='softmax'),
    dropout_rate=0.1
)

def init_model(config, key):
    # Initialize model parameters based on config
    encoder_key, decoder_key = jax.random.split(key)

    encoder_weights = jax.random.normal(
        encoder_key, (config.encoder.input_dim, config.encoder.output_dim)
    )
    decoder_weights = jax.random.normal(
        decoder_key, (config.decoder.input_dim, config.decoder.output_dim)
    )

    # Return arrays only: string fields such as `activation` cannot be
    # stacked along vmap's output axis.
    return {
        'encoder': encoder_weights,
        'decoder': decoder_weights
    }

# Vectorize over PRNG keys while sharing one config (in_axes=None)
init_fn = jax.vmap(init_model, in_axes=(None, 0))
keys = jax.random.split(jax.random.PRNGKey(42), 5)
models = init_fn(model_config, keys)  # Each array gains a leading axis of size 5
```

### Integration with JAX Transformations

```python
import chex
import jax
import jax.numpy as jnp

@chex.dataclass
class TrainingState:
    params: dict
    optimizer_state: dict
    step: int
    rng_key: jnp.ndarray

# Hypothetical placeholders: substitute your own update rules.
def update_params(params, batch):
    return jax.tree_util.tree_map(lambda p: p - 0.01 * jnp.mean(batch), params)

def update_optimizer(optimizer_state, batch):
    return optimizer_state

def update_step(state, batch):
    # Training step that updates the entire state
    new_params = update_params(state.params, batch)
    new_opt_state = update_optimizer(state.optimizer_state, batch)
    new_key, _ = jax.random.split(state.rng_key)

    return state.replace(
        params=new_params,
        optimizer_state=new_opt_state,
        step=state.step + 1,
        rng_key=new_key
    )

# Works with jit compilation
jitted_update = jax.jit(update_step)

# Works with scan for training loops
def train_loop(state, batches):
    final_state, _ = jax.lax.scan(
        lambda s, batch: (jitted_update(s, batch), None),
        state,
        batches
    )
    return final_state
```

### Manual Registration

```python
import dataclasses

import chex
import jax

# Create dataclass with standard library
@dataclasses.dataclass
class StandardConfig:
    value: float

# Manually register for JAX compatibility
chex.register_dataclass_type_with_jax_tree_util(StandardConfig)

# Now works with JAX tree utilities and transformations
config = StandardConfig(value=1.0)
jax.tree_util.tree_map(lambda x: x * 2, config)  # StandardConfig(value=2.0)
```

## Key Features

### Automatic PyTree Registration

All chex dataclasses are automatically registered as JAX pytrees, enabling:
- Seamless integration with `jax.tree_util.tree_map`, `jax.tree_util.tree_flatten`, etc.
- Support for JAX transformations (jit, vmap, grad, scan, etc.), provided the fields hold values JAX can trace
- Compatibility with gradient computation and optimization libraries

### Field Operations

Dataclasses support the standard field operations:
- `replace()` method for creating modified copies
- Field access and introspection
- Default values and factory functions
- Type hints (annotations document intent; they are not checked at runtime)

### Immutability Support

Frozen dataclasses provide:
- Immutable instances that can't be modified after creation
- Safe sharing across transformations
- Clear semantics for functional programming patterns

### Mapping Interface

Mappable dataclasses provide:
- Dict-style access patterns (`instance['key']`)
- Compatibility with dm-tree library
- Integration with dictionary-based workflows
- Iterator support (`keys()`, `values()`, `items()`)

## Best Practices

### Use Type Hints
```python
from typing import List

@chex.dataclass
class Config:
    learning_rate: float      # Clear type information
    layers: List[int]         # Supports complex types
    activation: str = 'relu'  # Default values
```

### Prefer Frozen for Immutable Data
```python
@chex.dataclass(frozen=True)
class Hyperparameters:
    lr: float
    batch_size: int
    # Immutable configuration
```

### Use Mappable for Dict-like Access
```python
@chex.dataclass  # Mapping interface is enabled by default
class Parameters:
    weights: jnp.ndarray
    bias: jnp.ndarray
    # Enables params['weights'] access
```