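# Type Definitions

Type aliases that Chex exports for annotating JAX code. They cover array types, pytree types, scalar and numeric types, shape and structure types, device and PRNG key types, and dtypes. The aliases are built from JAX, NumPy, and `typing` types. They document intent and help static type checkers, but they perform no checks at runtime.

## Capabilities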

### Array Types

Core array type definitions. `Array` accepts JAX arrays, NumPy arrays, and NumPy scalars (`np.bool_`, `np.number`). Python `int` and `float` belong to `Scalar` instead.

```python { .api }
# Base array types
ArrayNumpy = np.ndarray
ArrayDevice = jax.Array
ArraySharded = jax.Array  # Backward compatibility alias
ArrayBatched = jax.Array  # Backward compatibility alias

# Generic array type combining JAX and NumPy arrays
Array = Union[
    ArrayDevice,
    ArrayBatched,
    ArraySharded,
    ArrayNumpy,
    np.bool_,
    np.number
]
```
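To test membership at runtime, the union can be unpacked with `typing.get_args`. The sketch below does this and rebuilds the alias locally from the definitions above rather than importing it from Chex. The helper name `is_array` is hypothetical.

```python
from typing import Union, get_args

import jax
import jax.numpy as jnp
import numpy as np

# Local copy of the alias listed above. ArraySharded and ArrayBatched are the
# same class as ArrayDevice (jax.Array), so they add nothing at runtime.
ArrayNumpy = np.ndarray
ArrayDevice = jax.Array
Array = Union[ArrayDevice, ArrayNumpy, np.bool_, np.number]


def is_array(value) -> bool:
    """Return True if `value` is an instance of one of the Array union members."""
    return isinstance(value, get_args(Array))


print(is_array(np.zeros(3)))      # True: NumPy array
print(is_array(jnp.zeros(3)))     # True: JAX array
print(is_array(np.float32(1.0)))  # True: NumPy scalar (np.number)
print(is_array(1.0))              # False: a Python float is a Scalar
print(is_array([1.0, 2.0]))       # False: a list is a pytree, not an array
```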

### Tree Types

Type definitions for JAX pytrees containing arrays.

```python { .api }
# Tree of generic arrays
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]

# Tree of JAX device arrays
ArrayDeviceTree = Union[
    ArrayDevice,
    Iterable['ArrayDeviceTree'],
    Mapping[Any, 'ArrayDeviceTree']
]

# Tree of NumPy arrays
ArrayNumpyTree = Union[
    ArrayNumpy,
    Iterable['ArrayNumpyTree'],
    Mapping[Any, 'ArrayNumpyTree']
]
```
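The recursion in these aliases can be read as three cases: a node is an array leaf, a mapping whose values are trees, or an iterable whose items are trees. The hypothetical walker `count_array_leaves` makes that reading concrete. It is an illustration only. Real code should rely on `jax.tree_util`, which also handles registered custom pytree nodes and `None`.

```python
from collections.abc import Iterable, Mapping

import jax
import jax.numpy as jnp
import numpy as np

# Leaf types of the Array alias (jax.Array covers ArrayDevice/Sharded/Batched).
ARRAY_TYPES = (jax.Array, np.ndarray, np.bool_, np.number)


def count_array_leaves(tree) -> int:
    """Count array leaves by following the three ArrayTree cases."""
    if isinstance(tree, ARRAY_TYPES):  # Array case; checked first since arrays are iterable
        return 1
    if isinstance(tree, Mapping):  # Mapping[Any, ArrayTree] case
        return sum(count_array_leaves(v) for v in tree.values())
    if isinstance(tree, Iterable) and not isinstance(tree, (str, bytes)):
        return sum(count_array_leaves(v) for v in tree)  # Iterable[ArrayTree] case
    raise TypeError(f"not an ArrayTree node: {type(tree).__name__}")


params = {"w": jnp.ones((2, 2)), "b": [np.zeros(2), np.float32(1.0)]}
print(count_array_leaves(params))              # 3
print(len(jax.tree_util.tree_leaves(params)))  # 3
```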

### Scalar and Numeric Types

Type definitions for scalar values and numeric data.

```python { .api }
# Scalar types
Scalar = Union[float, int]

# Combined numeric type including arrays and scalars
Numeric = Union[Array, Scalar]
```
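`Scalar` covers only Python `int` and `float`. Reductions in `jax.numpy` return 0-d arrays, which match `Array` (and therefore `Numeric`) but not `Scalar`. A short check:

```python
import jax
import jax.numpy as jnp

loss = jnp.mean(jnp.array([1.0, 2.0]))
print(isinstance(loss, jax.Array), loss.shape)  # True ()
print(isinstance(loss, (int, float)))           # False: not a Python Scalar

as_python = float(loss)  # explicit conversion to a Python float
print(as_python)         # 1.5
```

Calling `float()` copies the value to the host and fails on traced values inside `jax.jit`. For the results of JAX computations, `Numeric` or `Array` is therefore usually the more accurate annotation.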

### Shape and Structure Types

Type definitions for array shapes and JAX structures.

```python { .api }
# Shape type allowing flexible dimension specifications
Shape = Sequence[Union[int, Any]]

# JAX pytree definition type
PyTreeDef = jax.tree_util.PyTreeDef
```
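A `Shape` is any sequence of dimension sizes. A `PyTreeDef` is the structure object that JAX returns when it flattens a pytree. The sketch below flattens a small parameter dict, inspects its `PyTreeDef`, and rebuilds the tree. `num_elements` is a hypothetical helper that assumes every dimension in the shape is a concrete `int`.

```python
import math
from typing import Any, Sequence, Union

import jax
import jax.numpy as jnp

Shape = Sequence[Union[int, Any]]  # as listed above


def num_elements(shape: Shape) -> int:
    """Number of elements for a fully specified (all-int) shape."""
    return math.prod(shape)


params = {"w": jnp.zeros((3, 4)), "b": jnp.zeros(4)}
leaves, treedef = jax.tree_util.tree_flatten(params)
print(isinstance(treedef, jax.tree_util.PyTreeDef))  # True
print(treedef.num_leaves)                            # 2
print([num_elements(x.shape) for x in leaves])       # [4, 12]: dict keys are sorted, "b" first

rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(jax.tree_util.tree_structure(rebuilt) == treedef)  # True
```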

### Device and Hardware Types

Type definitions for JAX devices and PRNG keys.

```python { .api }
# JAX device type
Device = jax.Device

# PRNG key type for random number generation
PRNGKey = jax.Array
```
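The snippet below shows where values of these types usually come from. `jax.devices()` returns `jax.Device` objects, and `jax.random.PRNGKey` returns a raw key stored in an ordinary `jax.Array`. It runs on whatever backend is available. Newer JAX versions also offer `jax.random.key`, which returns a typed key array that is still a `jax.Array`.

```python
import jax
import jax.numpy as jnp

device: jax.Device = jax.devices()[0]    # first available device (CPU, GPU, or TPU)
x = jax.device_put(jnp.ones(3), device)  # place the array on that device
print(x.devices() == {device})           # True

key = jax.random.PRNGKey(0)              # legacy raw key held in a jax.Array
print(isinstance(key, jax.Array))        # True
k1, k2 = jax.random.split(key)           # split before each independent use
sample = jax.random.normal(k1, (2,))
```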

### Data Type Definitions

Type definitions for array data types.

```python { .api }
# Array dtype type (version-dependent)
ArrayDType = jax.typing.DTypeLike  # JAX 0.4.19+
# ArrayDType = Any  # Older JAX versions
```
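Any value accepted by `jax.typing.DTypeLike` can be passed wherever `ArrayDType` is expected. That includes dtype names, JAX or NumPy scalar types, and `np.dtype` objects. The following is a minimal sketch that assumes a JAX version providing `jax.typing.DTypeLike`. The function name `cast` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def cast(x: jax.Array, dtype: jax.typing.DTypeLike) -> jax.Array:
    """Convert `x` to `dtype`; accepts any DTypeLike value."""
    return jnp.asarray(x, dtype=dtype)


x = jnp.arange(3)  # integer array
for spec in ("float32", jnp.float32, np.dtype("float32")):
    print(cast(x, spec).dtype)  # float32 for each spelling
```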

## Usage Examples

### Type Annotations

```python
import chex
import jax
import jax.numpy as jnp


def process_batch(
    data: chex.Array,
    weights: chex.ArrayTree,
    batch_size: int
) -> chex.Array:
    """Process a batch of data with given weights."""
    chex.assert_shape(data, (batch_size, None))  # Flexible feature dimension
    return jnp.dot(data, weights['linear']) + weights['bias']


def compute_loss(
    predictions: chex.Array,
    targets: chex.Array
) -> chex.Array:
    """Compute the mean squared error as a 0-d array."""
    return jnp.mean((predictions - targets) ** 2)


def create_model_state(
    params: chex.ArrayTree,
    optimizer_state: chex.ArrayTree,
    step: int,
    rng_key: chex.PRNGKey
) -> dict:
    """Create training state with proper types."""
    return {
        'params': params,
        'opt_state': optimizer_state,
        'step': step,
        'rng': rng_key
    }
```

### Shape Specifications

```python
from typing import Optional

import chex
import jax.numpy as jnp


def linear_layer(
    inputs: chex.Array,   # Shape: (batch, input_dim)
    weights: chex.Array,  # Shape: (input_dim, output_dim)
    bias: chex.Array      # Shape: (output_dim,)
) -> chex.Array:          # Shape: (batch, output_dim)
    """Linear transformation layer."""
    chex.assert_rank(inputs, 2)
    chex.assert_rank(weights, 2)
    chex.assert_rank(bias, 1)

    return jnp.dot(inputs, weights) + bias


# Flexible shape specifications
def process_sequence(
    sequence: chex.Array,              # Shape: (seq_len, batch, features)
    mask: Optional[chex.Array] = None  # Shape: (seq_len, batch) or None
) -> chex.Array:                       # Shape: (batch, features)
    """Average variable-length sequences over time, ignoring masked steps."""
    seq_len, batch_size, _ = sequence.shape

    if mask is None:
        return jnp.mean(sequence, axis=0)  # Average over sequence length

    chex.assert_shape(mask, (seq_len, batch_size))
    sequence = sequence * mask[..., None]
    # Divide by the number of valid steps per sequence (at least 1 to avoid 0/0)
    counts = jnp.maximum(jnp.sum(mask, axis=0), 1)
    return jnp.sum(sequence, axis=0) / counts[..., None]
```
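Shape comments are easiest to keep accurate when a function is written for a single example and batched with `jax.vmap`, so the batch axis never appears in the per-example signature. A sketch of that design choice applied to a linear map (the names `linear_single` and `linear_batched` are hypothetical):

```python
import jax
import jax.numpy as jnp


def linear_single(x, weights, bias):
    """Per-example map: x (input_dim,), weights (input_dim, output_dim), bias (output_dim,)."""
    return x @ weights + bias


# vmap adds a batch axis to `x` only; weights and bias are shared (in_axes=None)
linear_batched = jax.vmap(linear_single, in_axes=(0, None, None))

x = jnp.ones((5, 3))  # (batch, input_dim)
w = jnp.ones((3, 2))  # (input_dim, output_dim)
b = jnp.zeros(2)      # (output_dim,)
y = linear_batched(x, w, b)
print(y.shape)  # (5, 2)
```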
169
170
### Tree Type Usage
171
172
```python
173
def initialize_model(
174
key: chex.PRNGKey,
175
input_shape: chex.Shape
176
) -> chex.ArrayTree:
177
"""Initialize model parameters as a tree structure."""
178
179
keys = jax.random.split(key, 3)
180
181
params = {
182
'encoder': {
183
'weights': jax.random.normal(keys[0], (input_shape[-1], 128)),
184
'bias': jnp.zeros(128)
185
},
186
'decoder': {
187
'weights': jax.random.normal(keys[1], (128, 10)),
188
'bias': jnp.zeros(10)
189
},
190
'scale': jax.random.uniform(keys[2], (), minval=0.5, maxval=1.5)
191
}
192
193
return params
194
195
def apply_model(
196
params: chex.ArrayTree,
197
inputs: chex.Array
198
) -> chex.Array:
199
"""Apply model with tree-structured parameters."""
200
201
# Encoder
202
hidden = jnp.dot(inputs, params['encoder']['weights'])
203
hidden = hidden + params['encoder']['bias']
204
hidden = jax.nn.relu(hidden)
205
206
# Decoder
207
outputs = jnp.dot(hidden, params['decoder']['weights'])
208
outputs = outputs + params['decoder']['bias']
209
210
# Apply global scale
211
outputs = outputs * params['scale']
212
213
return outputs
214
215
def tree_statistics(tree: chex.ArrayTree) -> dict:
216
"""Compute statistics over a tree of arrays."""
217
218
def compute_stats(array: chex.Array) -> dict:
219
return {
220
'mean': jnp.mean(array),
221
'std': jnp.std(array),
222
'shape': array.shape
223
}
224
225
return jax.tree_map(compute_stats, tree)
226
```
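A common operation on parameter trees like the one `initialize_model` returns is counting parameters. The minimal sketch below uses zero-filled arrays with the same shapes and assumes a hypothetical input feature size of 16.

```python
import jax
import jax.numpy as jnp

params = {
    "encoder": {"weights": jnp.zeros((16, 128)), "bias": jnp.zeros(128)},
    "decoder": {"weights": jnp.zeros((128, 10)), "bias": jnp.zeros(10)},
    "scale": jnp.ones(()),
}


def count_params(tree) -> int:
    """Total number of scalar entries across all array leaves."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


print(count_params(params))  # 3467
```

The count is 16 * 128 + 128 + 128 * 10 + 10 + 1 = 3467.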
227
228
### Device Type Usage
229
230
```python
231
def distribute_computation(
232
data: chex.Array,
233
devices: list[chex.Device]
234
) -> chex.Array:
235
"""Distribute computation across multiple devices."""
236
237
n_devices = len(devices)
238
batch_size = data.shape[0]
239
240
# Ensure data can be evenly split
241
chex.assert_is_divisible(batch_size, n_devices)
242
243
# Split data across devices
244
per_device_size = batch_size // n_devices
245
split_data = data.reshape(n_devices, per_device_size, *data.shape[1:])
246
247
# Process on each device
248
def process_shard(shard):
249
return jnp.sum(shard, axis=0)
250
251
# Map across devices
252
results = jax.pmap(process_shard)(split_data)
253
254
return results
255
256
def check_device_placement(
257
array: chex.Array,
258
expected_device: chex.Device
259
) -> bool:
260
"""Check if array is placed on expected device."""
261
if hasattr(array, 'device'):
262
return array.device == expected_device
263
return True # NumPy arrays don't have device placement
264
```
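`jax.pmap` still works, but newer JAX code often expresses the same split with `jax.sharding`. The sketch below swaps in that API as an alternative to the functions above. It shards the rows of a batch across all devices and reduces them. It also shows why `ArraySharded` can be a plain alias: the sharded result is still an ordinary `jax.Array`. On a single-device machine the mesh simply has one device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("batch",))      # 1-D mesh over all devices
row_sharding = NamedSharding(mesh, PartitionSpec("batch"))  # split axis 0 over "batch"

n = len(devices)
data = jax.device_put(jnp.ones((4 * n, 3)), row_sharding)  # 4 rows per device

print(isinstance(data, jax.Array))        # True: still an ordinary jax.Array
print(len(data.addressable_shards) == n)  # True on a single host
column_sums = jnp.sum(data, axis=0)       # JAX handles the cross-device reduction
```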
265
266
### Numeric Type Usage
267
268
```python
269
def safe_divide(
270
numerator: chex.Numeric,
271
denominator: chex.Numeric,
272
epsilon: float = 1e-8
273
) -> chex.Numeric:
274
"""Safely divide numeric values with epsilon."""
275
276
# Handle both scalar and array inputs
277
if isinstance(denominator, (int, float)):
278
safe_denom = denominator + epsilon if denominator == 0 else denominator
279
else:
280
safe_denom = jnp.where(
281
jnp.abs(denominator) < epsilon,
282
epsilon,
283
denominator
284
)
285
286
return numerator / safe_denom
287
288
def normalize_features(
289
features: chex.Array,
290
axis: Optional[int] = None
291
) -> Tuple[chex.Array, chex.Scalar]:
292
"""Normalize features and return normalization constant."""
293
294
# Compute normalization factor
295
norm: chex.Scalar = jnp.linalg.norm(features, axis=axis, keepdims=True)
296
297
# Normalize
298
normalized = safe_divide(features, norm)
299
300
return normalized, jnp.squeeze(norm)
301
```
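Python-level type checks on `Numeric` arguments, such as an `isinstance` branch on `int` or `float`, run only while JAX traces a function. Under `jax.jit`, a Python scalar argument has already been converted to a traced array, so the check sees an array. Code that branches on scalar versus array inputs should therefore make every branch compute the same result. A minimal sketch (the names `trace_log` and `double` are hypothetical):

```python
import jax

trace_log = []


@jax.jit
def double(x):
    # This Python-level check runs once, while JAX traces the function
    trace_log.append(isinstance(x, (int, float)))
    return x * 2


result = double(3.0)
print(trace_log)      # [False]: the Python float arrived as a traced array
print(float(result))  # 6.0
```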
302
303
### Generic Type Functions
304
305
```python
306
from typing import TypeVar, Callable
307
308
T = TypeVar('T', bound=chex.ArrayTree)
309
310
def apply_tree_function(
311
tree: T,
312
fn: Callable[[chex.Array], chex.Array]
313
) -> T:
314
"""Apply function to all arrays in tree, preserving structure."""
315
return jax.tree_map(fn, tree)
316
317
def validate_tree_structure(
318
tree1: chex.ArrayTree,
319
tree2: chex.ArrayTree
320
) -> bool:
321
"""Validate that two trees have the same structure."""
322
try:
323
jax.tree_map(lambda x, y: None, tree1, tree2)
324
return True
325
except (TypeError, ValueError):
326
return False
327
328
def convert_tree_dtype(
329
tree: chex.ArrayTree,
330
dtype: chex.ArrayDType
331
) -> chex.ArrayTree:
332
"""Convert all arrays in tree to specified dtype."""
333
return jax.tree_map(lambda x: x.astype(dtype), tree)
334
```
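Mapping a function over two trees is not an exact structure check. `jax.tree_util.tree_map` only requires each additional tree to have the first tree's structure as a prefix. Comparing `PyTreeDef` objects is strict. A small demonstration:

```python
import jax
import jax.numpy as jnp

full = {"a": jnp.ones(2), "b": (jnp.ones(2), jnp.ones(2))}
prefix = {"a": jnp.ones(2), "b": jnp.ones(2)}  # "b" is one leaf instead of a tuple

# tree_map accepts `full` because `prefix` is a prefix of it; the whole
# tuple under "b" is passed to the function as a single argument
pairs = jax.tree_util.tree_map(lambda x, y: isinstance(y, tuple), prefix, full)
print(pairs)  # {'a': False, 'b': True}

same = jax.tree_util.tree_structure(prefix) == jax.tree_util.tree_structure(full)
print(same)  # False
```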
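
## Type Compatibility

### JAX Integration
The Chex aliases are built directly from JAX types, so annotated values are ordinary JAX values:
- Values annotated with the array types are plain `jax.Array` or NumPy arrays, so they pass through `jax.jit`, `jax.vmap`, `jax.grad`, and other transformations unchanged
- The tree types match nested lists, tuples, and dicts of arrays, which are standard pytrees accepted by `jax.tree_util`
- `Shape` accepts any sequence of dimension sizes, including the tuples returned by `array.shape`
- `Device` is `jax.Device`, so it annotates the objects returned by `jax.devices()` and accepted by `jax.device_put`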
344
345
### NumPy Compatibility
346
Chex types maintain NumPy compatibility:
347
- Array types include NumPy arrays
348
- Scalar types work with NumPy operations
349
- Shape specifications support NumPy broadcasting
350
351
### Version Compatibility
352
Type definitions adapt to JAX version differences:
353
- ArrayDType uses JAX's DTypeLike when available
354
- Backward compatibility aliases for deprecated types
355
- Future-proof type specifications
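One way to express the `ArrayDType` fallback is a guarded import. This is a sketch of the pattern and not necessarily how the Chex source implements it:

```python
from typing import Any

try:
    # Available in JAX releases that ship jax.typing.DTypeLike
    from jax.typing import DTypeLike as ArrayDType
except ImportError:
    # Older JAX (or no jax.typing module): accept any dtype specification
    ArrayDType = Any
```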

## Best Practices

### Use Specific Types
```python
import chex


# Good: Specific type information
def process_images(images: chex.Array) -> chex.Array:
    chex.assert_rank(images, 4)  # (batch, height, width, channels)
    return images


# Better: Include shape information in docstring
def process_images(images: chex.Array) -> chex.Array:
    """Process batch of images.

    Args:
        images: Array of shape (batch, height, width, channels)

    Returns:
        Processed images of same shape
    """
    chex.assert_rank(images, 4)
    return images
```
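
### Combine with Assertions
```python
import chex
import jax
import jax.numpy as jnp
import numpy as np


def typed_function(
    data: chex.Array,
    weights: chex.ArrayTree
) -> chex.Array:
    # Runtime validation matches type annotations.
    # chex.assert_type checks dtypes, so use isinstance for the array type itself.
    assert isinstance(data, (jax.Array, np.ndarray)), type(data)
    chex.assert_tree_has_only_ndarrays(weights)

    # Illustrative body: assumes `weights` is a dict with 'w' and 'b' entries
    return jnp.dot(data, weights['w']) + weights['b']
```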
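
### Document Shape Expectations
```python
import chex
import jax
import jax.numpy as jnp


def attention_layer(
    query: chex.Array,  # (batch, seq_q, dim)
    key: chex.Array,    # (batch, seq_k, dim)
    value: chex.Array   # (batch, seq_k, dim)
) -> chex.Array:        # (batch, seq_q, dim)
    """Single-head scaled dot-product attention with clear shape specifications."""
    chex.assert_rank([query, key, value], 3)
    chex.assert_equal_shape([key, value])
    scores = jnp.einsum('bqd,bkd->bqk', query, key) / jnp.sqrt(query.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)  # (batch, seq_q, seq_k)
    return jnp.einsum('bqk,bkd->bqd', weights, value)
```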