# JAX-Compatible Dataclasses

Chex provides a JAX-compatible dataclass implementation. Chex dataclasses are automatically registered as JAX pytrees, so they can be passed through JAX transformations such as `jit`, `vmap`, and `grad`.

## Capabilities

### Core Dataclass Functionality

The `chex.dataclass` decorator creates structured data containers that work with the JAX ecosystem.

```python { .api }
def dataclass(
    cls=None,
    *,
    init=True,
    repr=True,
    eq=True,
    order=False,
    unsafe_hash=False,
    frozen=False,
    mappable_dataclass=True
):
    """
    JAX-compatible dataclass decorator.

    Parameters:
    - cls: Class to decorate (when used without parentheses)
    - init: Generate __init__ method
    - repr: Generate __repr__ method
    - eq: Generate __eq__ method
    - order: Generate ordering methods (__lt__, __le__, __gt__, __ge__)
    - unsafe_hash: Generate __hash__ method (use with caution)
    - frozen: Make instances immutable
    - mappable_dataclass: Make dataclass compatible with collections.abc.Mapping (enabled by default)

    Returns:
    - Decorated class registered as JAX pytree
    """

def mappable_dataclass(cls):
    """
    Make dataclass compatible with collections.abc.Mapping interface.

    Allows dataclass instances to be used with dm-tree library and provides
    dict-like access patterns. Changes constructor to dict-style (no positional args).

    Parameters:
    - cls: A dataclass to make mappable

    Returns:
    - Modified dataclass implementing collections.abc.Mapping

    Raises:
    - ValueError: If cls is not a dataclass
    """

def register_dataclass_type_with_jax_tree_util(cls):
    """
    Manually register a dataclass type with JAX tree utilities.

    Normally done automatically by @chex.dataclass, but can be called
    manually for dataclasses created with other decorators.

    Parameters:
    - cls: Dataclass type to register
    """
```

### Dataclass Exceptions

Exception types for dataclass operations.

```python { .api }
FrozenInstanceError = dataclasses.FrozenInstanceError
```

## Usage Examples

### Basic Dataclass Usage

```python
import chex
import jax
import jax.numpy as jnp

@chex.dataclass
class Config:
    learning_rate: float
    batch_size: int
    hidden_dims: tuple

# Create instance
config = Config(
    learning_rate=0.01,
    batch_size=32,
    hidden_dims=(128, 64)
)

# Works with JAX transformations
def compute_loss(config, data):
    # Use config parameters in computation
    return jnp.sum(data) * config.learning_rate

# Can be passed through jit, vmap, etc.
jitted_loss = jax.jit(compute_loss)
result = jitted_loss(config, jnp.array([1.0, 2.0, 3.0]))
```

### Frozen Dataclasses

```python
@chex.dataclass(frozen=True)
class ImmutableConfig:
    model_name: str
    version: int

config = ImmutableConfig(model_name="transformer", version=1)

# This would raise FrozenInstanceError
# config.version = 2  # Error!

# Use replace() to create modified copies
new_config = config.replace(version=2)
```

### Mappable Dataclasses

```python
# chex.dataclass applies the mapping interface by default
# (mappable_dataclass=True), so no extra decorator is needed here.
# Use chex.mappable_dataclass to add the same interface to a
# standard dataclasses.dataclass.
@chex.dataclass
class Parameters:
    weights: jnp.ndarray
    bias: jnp.ndarray
    scale: float = 1.0

# Can be created dict-style (no positional field args)
params = Parameters({
    'weights': jnp.ones((10, 5)),
    'bias': jnp.zeros(5),
    'scale': 0.5
})

# Supports dict-like operations
print(params['weights'].shape)  # (10, 5)
print(list(params.keys()))  # ['weights', 'bias', 'scale']
print(len(params))  # 3

# Works with dm-tree
import tree
flat_params = tree.flatten(params)
```

### Nested Dataclasses

```python
@chex.dataclass
class LayerConfig:
    input_dim: int
    output_dim: int
    activation: str

@chex.dataclass
class ModelConfig:
    encoder: LayerConfig
    decoder: LayerConfig
    dropout_rate: float

# Create nested structure
model_config = ModelConfig(
    encoder=LayerConfig(input_dim=784, output_dim=128, activation='relu'),
    decoder=LayerConfig(input_dim=128, output_dim=10, activation='softmax'),
    dropout_rate=0.1
)

# Works seamlessly with JAX transformations
def init_model(config, key):
    # Initialize model parameters based on config
    encoder_key, decoder_key = jax.random.split(key)

    encoder_weights = jax.random.normal(
        encoder_key, (config.encoder.input_dim, config.encoder.output_dim)
    )
    decoder_weights = jax.random.normal(
        decoder_key, (config.decoder.input_dim, config.decoder.output_dim)
    )

    # Return only arrays: vmap cannot batch the string fields held in config
    return {
        'encoder': encoder_weights,
        'decoder': decoder_weights,
    }

# Vectorize over RNG keys while sharing a single config (in_axes=None)
init_fn = jax.vmap(init_model, in_axes=(None, 0))
keys = jax.random.split(jax.random.PRNGKey(42), 5)
models = init_fn(model_config, keys)
```

### Integration with JAX Transformations

```python
@chex.dataclass
class TrainingState:
    params: dict
    optimizer_state: dict
    step: int
    rng_key: jnp.ndarray

def update_step(state, batch):
    # Training step that updates the entire state.
    # update_params and update_optimizer are placeholders for user-defined functions.
    new_params = update_params(state.params, batch)
    new_opt_state = update_optimizer(state.optimizer_state, batch)
    new_key, _ = jax.random.split(state.rng_key)

    return state.replace(
        params=new_params,
        optimizer_state=new_opt_state,
        step=state.step + 1,
        rng_key=new_key
    )

# Works with jit compilation
jitted_update = jax.jit(update_step)

# Works with scan for training loops
def train_loop(state, batches):
    final_state, _ = jax.lax.scan(
        lambda s, batch: (jitted_update(s, batch), None),
        state,
        batches
    )
    return final_state
```

### Manual Registration

```python
import dataclasses

# Create dataclass with standard library
@dataclasses.dataclass
class StandardConfig:
    value: float

# Manually register for JAX compatibility
chex.register_dataclass_type_with_jax_tree_util(StandardConfig)

# Now works with JAX transformations
config = StandardConfig(value=1.0)
# jax.tree_map is deprecated; use jax.tree_util.tree_map
jax.tree_util.tree_map(lambda x: x * 2, config)  # StandardConfig(value=2.0)
```

## Key Features

### Automatic PyTree Registration

All Chex dataclasses are automatically registered as JAX pytrees, enabling:
- Seamless integration with `jax.tree_util.tree_map`, `jax.tree_util.tree_flatten`, etc.
- Support for all JAX transformations (jit, vmap, grad, scan, etc.)
- Compatibility with gradient computation and optimization libraries

### Field Operations

Chex dataclasses support the standard dataclass field operations:
- `replace()` method for creating modified copies
- Field access and introspection
- Default values and factory functions
- Type hints (annotations are not validated at runtime)

### Immutability Support

Frozen dataclasses provide:
- Immutable instances that can't be modified after creation
- Safe sharing across transformations
- Clear semantics for functional programming patterns

### Mapping Interface

Mappable dataclasses provide:
- Dict-style access patterns (`instance['key']`)
- Compatibility with dm-tree library
- Integration with dictionary-based workflows
- Iterator support (`keys()`, `values()`, `items()`)

## Best Practices

### Use Type Hints

```python
from typing import List

@chex.dataclass
class Config:
    learning_rate: float      # Clear type information
    layers: List[int]         # Supports complex types
    activation: str = 'relu'  # Default values
```

### Prefer Frozen for Immutable Data

```python
@chex.dataclass(frozen=True)
class Hyperparameters:
    lr: float
    batch_size: int
    # Immutable configuration
```

### Use Mappable for Dict-like Access

```python
@chex.dataclass  # mappable by default (mappable_dataclass=True)
class Parameters:
    weights: jnp.ndarray
    bias: jnp.ndarray
    # Enables params['weights'] access
```