
# Type Definitions

Complete type system for JAX development including array types, shape specifications, device types, and other computational primitives used throughout the JAX ecosystem.

## Capabilities

### Array Types

Core array type definitions for JAX and NumPy arrays.

```python { .api }
# Base array types
ArrayNumpy = np.ndarray
ArrayDevice = jax.Array
ArraySharded = jax.Array  # Backward compatibility alias
ArrayBatched = jax.Array  # Backward compatibility alias

# Generic array type combining JAX and NumPy arrays
Array = Union[
    ArrayDevice,
    ArrayBatched,
    ArraySharded,
    ArrayNumpy,
    np.bool_,
    np.number
]
```
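Because `Array` is a union over JAX and NumPy array types, a single annotation covers inputs from either library. A minimal sketch (the `total` helper is hypothetical, not part of chex):

```python
import chex
import jax.numpy as jnp
import numpy as np

def total(x: chex.Array) -> chex.Numeric:
    """Sum all elements; accepts JAX or NumPy arrays."""
    return jnp.sum(x)

total(np.ones((2, 3)))   # NumPy input -> 6.0
total(jnp.ones((2, 3)))  # JAX input -> 6.0
```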

### Tree Types

Type definitions for JAX pytrees containing arrays.

```python { .api }
# Tree of generic arrays
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]

# Tree of JAX device arrays
ArrayDeviceTree = Union[
    ArrayDevice,
    Iterable['ArrayDeviceTree'],
    Mapping[Any, 'ArrayDeviceTree']
]

# Tree of NumPy arrays
ArrayNumpyTree = Union[
    ArrayNumpy,
    Iterable['ArrayNumpyTree'],
    Mapping[Any, 'ArrayNumpyTree']
]
```
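Any nesting of lists, tuples, and dicts of arrays satisfies `ArrayTree`, so the alias composes directly with JAX's pytree utilities. A minimal sketch:

```python
import chex
import jax
import jax.numpy as jnp

tree: chex.ArrayTree = {
    'layer': [jnp.ones(3), jnp.zeros((2, 2))],
    'scale': jnp.array(1.0),
}

# Pytree utilities treat the arrays as leaves.
leaves = jax.tree_util.tree_leaves(tree)
assert len(leaves) == 3
```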

### Scalar and Numeric Types

Type definitions for scalar values and numeric data.

```python { .api }
# Scalar types
Scalar = Union[float, int]

# Combined numeric type including arrays and scalars
Numeric = Union[Array, Scalar]
```
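Since `Numeric` is the union of `Array` and `Scalar`, one signature can accept Python scalars and arrays alike. A minimal sketch (`scale_by` is hypothetical):

```python
import chex
import jax.numpy as jnp

def scale_by(x: chex.Numeric, factor: chex.Scalar) -> chex.Numeric:
    """Multiply a scalar or array by a scalar factor."""
    return x * factor

scale_by(2.0, 3)            # scalar input -> 6.0
scale_by(jnp.ones(4), 0.5)  # array input -> [0.5 0.5 0.5 0.5]
```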

### Shape and Structure Types

Type definitions for array shapes and JAX structures.

```python { .api }
# Shape type allowing flexible dimension specifications
Shape = Sequence[Union[int, Any]]

# JAX pytree definition type
PyTreeDef = jax.tree_util.PyTreeDef
```
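`PyTreeDef` is what `jax.tree_util.tree_flatten` returns alongside the leaves; keeping it around lets you rebuild the original structure. A minimal sketch:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# treedef has type chex.PyTreeDef
leaves, treedef = jax.tree_util.tree_flatten(params)
doubled = [leaf * 2 for leaf in leaves]
rebuilt = jax.tree_util.tree_unflatten(treedef, doubled)
```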

### Device and Hardware Types

Type definitions for JAX devices and hardware.

```python { .api }
# JAX device type
Device = jax.Device

# PRNG key type for random number generation
PRNGKey = jax.Array
```
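Both aliases line up with values JAX produces at runtime: `jax.devices()` yields `Device` objects and `jax.random.PRNGKey` yields a key array. A minimal sketch:

```python
import chex
import jax

devices: list[chex.Device] = jax.devices()
key: chex.PRNGKey = jax.random.PRNGKey(0)

print(devices[0].platform)  # e.g. 'cpu', 'gpu', or 'tpu'
key, subkey = jax.random.split(key)
```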

### Data Type Definitions

Type definitions for array data types.

```python { .api }
# Array dtype type (version-dependent)
ArrayDType = jax.typing.DTypeLike  # JAX 0.4.19+
# ArrayDType = Any                 # Older JAX versions
```
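`DTypeLike` accepts dtype objects, NumPy scalar types, and dtype strings, so any of the following works as a `chex.ArrayDType` value. A minimal sketch:

```python
import jax.numpy as jnp
import numpy as np

# All three forms are valid DTypeLike values.
for dtype in (jnp.float32, np.int32, 'float64'):
    x = np.zeros(3, dtype=dtype)
```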

## Usage Examples

### Type Annotations

```python
import chex
import jax
import jax.numpy as jnp
from typing import Tuple, Optional

def process_batch(
    data: chex.Array,
    weights: chex.ArrayTree,
    batch_size: int
) -> chex.Array:
    """Process a batch of data with given weights."""
    chex.assert_shape(data, (batch_size, None))  # Flexible feature dimension
    return jnp.dot(data, weights['linear']) + weights['bias']

def compute_loss(
    predictions: chex.Array,
    targets: chex.Array
) -> chex.Scalar:
    """Compute scalar loss value."""
    return jnp.mean((predictions - targets) ** 2)

def create_model_state(
    params: chex.ArrayTree,
    optimizer_state: chex.ArrayTree,
    step: int,
    rng_key: chex.PRNGKey
) -> dict:
    """Create training state with proper types."""
    return {
        'params': params,
        'opt_state': optimizer_state,
        'step': step,
        'rng': rng_key
    }
```

### Shape Specifications

```python
from typing import Optional

def linear_layer(
    inputs: chex.Array,   # Shape: (batch, input_dim)
    weights: chex.Array,  # Shape: (input_dim, output_dim)
    bias: chex.Array      # Shape: (output_dim,)
) -> chex.Array:          # Shape: (batch, output_dim)
    """Linear transformation layer."""
    chex.assert_rank(inputs, 2)
    chex.assert_rank(weights, 2)
    chex.assert_rank(bias, 1)

    return jnp.dot(inputs, weights) + bias

# Flexible shape specifications
def process_sequence(
    sequence: chex.Array,              # Shape: (seq_len, batch, features)
    mask: Optional[chex.Array] = None  # Shape: (seq_len, batch) or None
) -> chex.Array:                       # Shape: (batch, features)
    """Process variable-length sequences."""
    seq_len, batch_size, features = sequence.shape

    if mask is not None:
        chex.assert_shape(mask, (seq_len, batch_size))
        sequence = sequence * mask[..., None]

    return jnp.mean(sequence, axis=0)  # Average over sequence length
```

### Tree Type Usage

```python
def initialize_model(
    key: chex.PRNGKey,
    input_shape: chex.Shape
) -> chex.ArrayTree:
    """Initialize model parameters as a tree structure."""
    keys = jax.random.split(key, 3)

    params = {
        'encoder': {
            'weights': jax.random.normal(keys[0], (input_shape[-1], 128)),
            'bias': jnp.zeros(128)
        },
        'decoder': {
            'weights': jax.random.normal(keys[1], (128, 10)),
            'bias': jnp.zeros(10)
        },
        'scale': jax.random.uniform(keys[2], (), minval=0.5, maxval=1.5)
    }

    return params

def apply_model(
    params: chex.ArrayTree,
    inputs: chex.Array
) -> chex.Array:
    """Apply model with tree-structured parameters."""
    # Encoder
    hidden = jnp.dot(inputs, params['encoder']['weights'])
    hidden = hidden + params['encoder']['bias']
    hidden = jax.nn.relu(hidden)

    # Decoder
    outputs = jnp.dot(hidden, params['decoder']['weights'])
    outputs = outputs + params['decoder']['bias']

    # Apply global scale
    outputs = outputs * params['scale']

    return outputs

def tree_statistics(tree: chex.ArrayTree) -> dict:
    """Compute statistics over a tree of arrays."""

    def compute_stats(array: chex.Array) -> dict:
        return {
            'mean': jnp.mean(array),
            'std': jnp.std(array),
            'shape': array.shape
        }

    # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable name.
    return jax.tree_util.tree_map(compute_stats, tree)
```

227

228

### Device Type Usage

229

230

```python

231

def distribute_computation(

232

data: chex.Array,

233

devices: list[chex.Device]

234

) -> chex.Array:

235

"""Distribute computation across multiple devices."""

236

237

n_devices = len(devices)

238

batch_size = data.shape[0]

239

240

# Ensure data can be evenly split

241

chex.assert_is_divisible(batch_size, n_devices)

242

243

# Split data across devices

244

per_device_size = batch_size // n_devices

245

split_data = data.reshape(n_devices, per_device_size, *data.shape[1:])

246

247

# Process on each device

248

def process_shard(shard):

249

return jnp.sum(shard, axis=0)

250

251

# Map across devices

252

results = jax.pmap(process_shard)(split_data)

253

254

return results

255

256

def check_device_placement(

257

array: chex.Array,

258

expected_device: chex.Device

259

) -> bool:

260

"""Check if array is placed on expected device."""

261

if hasattr(array, 'device'):

262

return array.device == expected_device

263

return True # NumPy arrays don't have device placement

264

```

### Numeric Type Usage

```python
def safe_divide(
    numerator: chex.Numeric,
    denominator: chex.Numeric,
    epsilon: float = 1e-8
) -> chex.Numeric:
    """Safely divide numeric values with epsilon."""
    # Handle both scalar and array inputs
    if isinstance(denominator, (int, float)):
        safe_denom = denominator + epsilon if denominator == 0 else denominator
    else:
        safe_denom = jnp.where(
            jnp.abs(denominator) < epsilon,
            epsilon,
            denominator
        )

    return numerator / safe_denom

def normalize_features(
    features: chex.Array,
    axis: Optional[int] = None
) -> Tuple[chex.Array, chex.Scalar]:
    """Normalize features and return normalization constant."""
    # Compute normalization factor (an array here, because keepdims=True)
    norm: chex.Array = jnp.linalg.norm(features, axis=axis, keepdims=True)

    # Normalize
    normalized = safe_divide(features, norm)

    return normalized, jnp.squeeze(norm)
```

### Generic Type Functions

```python
from typing import TypeVar, Callable

T = TypeVar('T', bound=chex.ArrayTree)

def apply_tree_function(
    tree: T,
    fn: Callable[[chex.Array], chex.Array]
) -> T:
    """Apply function to all arrays in tree, preserving structure."""
    return jax.tree_util.tree_map(fn, tree)

def validate_tree_structure(
    tree1: chex.ArrayTree,
    tree2: chex.ArrayTree
) -> bool:
    """Validate that two trees have the same structure."""
    try:
        jax.tree_util.tree_map(lambda x, y: None, tree1, tree2)
        return True
    except (TypeError, ValueError):
        return False

def convert_tree_dtype(
    tree: chex.ArrayTree,
    dtype: chex.ArrayDType
) -> chex.ArrayTree:
    """Convert all arrays in tree to specified dtype."""
    return jax.tree_util.tree_map(lambda x: x.astype(dtype), tree)
```

## Type Compatibility

### JAX Integration

All Chex types are designed for seamless integration with JAX (see the sketch after this list):

- Array types work with all JAX transformations
- Tree types support JAX pytree operations
- Shape types enable flexible dimension handling
- Device types support multi-device computation
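A minimal sketch of the first point: the aliases are ordinary annotations, so annotated functions pass through `jax.jit` and `jax.grad` unchanged (`squared_norm` is hypothetical):

```python
import chex
import jax
import jax.numpy as jnp

@jax.jit
def squared_norm(x: chex.Array) -> chex.Scalar:
    return jnp.sum(x ** 2)

grad_fn = jax.grad(squared_norm)
print(grad_fn(jnp.arange(3.0)))  # [0. 2. 4.]
```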

### NumPy Compatibility

Chex types maintain NumPy compatibility:

- Array types include NumPy arrays
- Scalar types work with NumPy operations
- Shape specifications support NumPy broadcasting

### Version Compatibility

Type definitions adapt to JAX version differences (see the sketch below):

- ArrayDType uses JAX's DTypeLike when available
- Backward compatibility aliases for deprecated types
- Future-proof type specifications
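A minimal sketch of how such a version guard might look, assuming the fallback to `Any` described under Data Type Definitions:

```python
from typing import Any

try:
    import jax.typing
    ArrayDType = jax.typing.DTypeLike  # JAX 0.4.19+
except (ImportError, AttributeError):
    ArrayDType = Any  # Older JAX versions
```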

## Best Practices

### Use Specific Types

```python
# Good: Specific type information
def process_images(images: chex.Array) -> chex.Array:
    chex.assert_rank(images, 4)  # (batch, height, width, channels)
    return images

# Better: Also include shape information in the docstring
def process_images(images: chex.Array) -> chex.Array:
    """Process batch of images.

    Args:
        images: Array of shape (batch, height, width, channels)

    Returns:
        Processed images of same shape
    """
    chex.assert_rank(images, 4)
    return images
```

### Combine with Assertions

```python
def typed_function(
    data: chex.Array,
    weights: chex.ArrayTree
) -> chex.Array:
    # Runtime validation complements the type annotations:
    # assert_type checks the array's dtype kind (here, floating point).
    chex.assert_type(data, float)
    chex.assert_tree_has_only_ndarrays(weights)

    # process_data stands in for your own computation.
    return process_data(data, weights)
```

### Document Shape Expectations

```python
def attention_layer(
    query: chex.Array,  # (batch, seq_q, dim)
    key: chex.Array,    # (batch, seq_k, dim)
    value: chex.Array   # (batch, seq_k, dim)
) -> chex.Array:        # (batch, seq_q, dim)
    """Multi-head attention with clear shape specifications."""
    pass
```