# Advanced Features

Specialized utilities including backend restriction, dimension mapping, jittable assertions, and deprecation management for advanced JAX development scenarios.

## Capabilities

### Backend Restriction

Context manager for controlling JAX backend compilation and device usage.

```python { .api }
def restrict_backends(*, allowed=None, forbidden=None):
    """
    Context manager that prevents JAX compilation for specified backends.

    Useful for ensuring code runs only on intended devices or for catching
    accidental compilation on restricted hardware.

    Parameters:
    - allowed: Sequence of allowed backend platform names (e.g., ['cpu', 'gpu'])
    - forbidden: Sequence of forbidden backend platform names

    Yields:
    - Context where compilation for forbidden platforms raises RestrictedBackendError

    Raises:
    - ValueError: If neither allowed nor forbidden is specified, or if the two conflict
    - RestrictedBackendError: If compilation is attempted on a restricted backend
    """

class RestrictedBackendError(RuntimeError):
    """
    Exception raised when compilation is attempted on a restricted backend.
    """
```
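
A minimal sketch of the argument validation described in the Raises section above; the ValueError is raised when the context is entered:

```python
import chex

# Passing neither `allowed` nor `forbidden` is invalid.
try:
    with chex.restrict_backends():
        pass
except ValueError:
    print("Must specify at least one of `allowed` or `forbidden`")
```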

### Dimension Mapping

Utility class for managing named dimensions and shape specifications.

```python { .api }
class Dimensions:
    """
    Lightweight utility that maps strings to shape tuples.

    Enables readable shape specifications using named dimensions
    and supports dimension arithmetic and wildcard dimensions.

    Examples:
    >>> dims = chex.Dimensions(B=3, T=5, N=7)
    >>> dims['NBT']    # (7, 3, 5)
    >>> dims['(BT)N']  # (15, 7) - flattened dimensions
    >>> dims['BT*']    # (3, 5, None) - wildcard dimension
    """

    def __init__(self, **kwargs):
        """
        Initialize dimensions with named size mappings.

        Parameters:
        - **kwargs: Dimension name to size mappings (e.g., B=32, T=100)
        """

    def __getitem__(self, key):
        """
        Get the shape tuple for a dimension string specification.

        Parameters:
        - key: String specifying dimensions (e.g., 'BTC', '(BT)C', 'BT*')

        Returns:
        - Tuple of integers and/or None for wildcard dimensions
        """

    def __setitem__(self, key, value):
        """
        Set dimension sizes from a shape tuple.

        Parameters:
        - key: String specifying dimensions
        - value: Shape tuple to assign to the named dimensions
        """

    def size(self, key):
        """
        Get the total size (product) of the specified dimensions.

        Parameters:
        - key: String specifying dimensions

        Returns:
        - Total number of elements in the specified shape
        """
```
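
A short sketch of the mapping semantics documented above, reusing the toy sizes from the docstring; note that `__setitem__` takes a shape tuple, not a bare integer:

```python
import chex

dims = chex.Dimensions(B=3, T=5)
dims['N'] = (7,)                 # Assign a new named dimension from a shape tuple.
assert dims['BTN'] == (3, 5, 7)
assert dims['(BT)N'] == (15, 7)  # Flattened leading dimensions.
assert dims.size('BT') == 15     # Product of the named sizes.
```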

### Jittable Assertions

Advanced assertion system that works inside jitted functions using JAX checkify.

```python { .api }
def chexify(
    fn,
    async_check=True,
    errors=ChexifyChecks.user
):
    """
    Enable Chex value assertions inside jitted functions.

    Wraps a function to enable runtime assertions that work with JAX
    transformations by using JAX's checkify system for delayed error checking.

    Parameters:
    - fn: Function to wrap with jittable assertions
    - async_check: Whether to check errors asynchronously
    - errors: Set of error categories to check (from ChexifyChecks)

    Returns:
    - Wrapped function that supports Chex assertions inside jit
    """

def with_jittable_assertions(fn):
    """
    Decorator for enabling jittable assertions in a function.

    Equivalent to chexify(fn) but as a decorator.

    Parameters:
    - fn: Function to decorate

    Returns:
    - Function with jittable assertions enabled
    """

def block_until_chexify_assertions_complete():
    """
    Wait for all asynchronous assertion checks to complete.

    Should be called after computations that use chexify to ensure
    all assertion errors are properly surfaced.
    """

class ChexifyChecks:
    """
    Collection of checkify error categories for jittable assertions.

    Attributes:
    - user: User-defined checks (Chex assertions)
    - nan: NaN detection checks
    - index: Array indexing checks
    - div: Division by zero checks
    - float: Floating point error checks
    - automatic: Automatically enabled checks
    - all: All available checks
    """
```
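
A minimal sketch of the wrapping pattern, assuming only the APIs shown above. `chexify` takes the function as its first positional argument, so keyword configuration goes through `functools.partial`; `async_check=False` makes failures surface at the call site rather than at the next synchronization point:

```python
import functools

import chex
import jax
import jax.numpy as jnp

@functools.partial(chex.chexify, async_check=False)
@jax.jit
def log1p_abs(x):
    chex.assert_tree_all_finite(x)  # Checked synchronously at call time.
    return jnp.log(jnp.abs(x) + 1.0)

print(log1p_abs(jnp.array([1.0, -2.0])))
```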

### Deprecation Management

Utilities for managing deprecated functions and warning users about API changes.

```python { .api }
def warn_deprecated_function(fun, replacement=None):
    """
    Decorator to mark a function as deprecated.

    Emits a DeprecationWarning when the decorated function is called.

    Parameters:
    - fun: Function to mark as deprecated
    - replacement: Optional name of the replacement function

    Returns:
    - Wrapped function that emits a deprecation warning
    """

def create_deprecated_function_alias(fun, new_name, deprecated_alias):
    """
    Create a deprecated alias for a function.

    Creates a new function that emits a deprecation warning and delegates
    to the original function.

    Parameters:
    - fun: Original function
    - new_name: Current name of the function
    - deprecated_alias: Deprecated alias name

    Returns:
    - Deprecated alias function
    """

def warn_only_n_pos_args_in_future(fun, n):
    """
    Warn if more than n positional arguments are passed.

    Helps transition functions to keyword-only arguments by warning
    when too many positional arguments are used.

    Parameters:
    - fun: Function to wrap
    - n: Maximum number of allowed positional arguments

    Returns:
    - Wrapped function that warns about excess positional arguments
    """

def warn_keyword_args_only_in_future(fun):
    """
    Warn if any positional arguments are passed (keyword-only transition).

    Equivalent to warn_only_n_pos_args_in_future(fun, 0).

    Parameters:
    - fun: Function to wrap

    Returns:
    - Wrapped function that warns about positional arguments
    """
```
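
Because these helpers emit a standard DeprecationWarning, tests can assert on them with the warnings module. A sketch under that assumption (`old_fn` is a hypothetical example function):

```python
import functools
import warnings

import chex

@functools.partial(chex.warn_deprecated_function, replacement='new_fn')
def old_fn(x):
    return x + 1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')  # Don't let the once-per-location filter hide it.
    old_fn(1)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```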

## Usage Examples

### Backend Restriction

```python
import functools

import chex
import jax
import jax.numpy as jnp

# Ensure computation only runs on CPU
with chex.restrict_backends(allowed=['cpu']):
    @jax.jit
    def cpu_only_computation(x):
        return x ** 2

    result = cpu_only_computation(jnp.array([1, 2, 3]))
    # Works fine - compiles for CPU

# Prevent accidental GPU/TPU usage. The restriction applies at compilation
# time, which happens on the first call to a jitted function.
with chex.restrict_backends(forbidden=['gpu', 'tpu']):
    try:
        @functools.partial(jax.jit, backend='gpu')  # Request GPU compilation
        def gpu_computation(x):
            return x + 1

        gpu_computation(jnp.array([1]))  # Compilation is attempted here
    except chex.RestrictedBackendError:
        print("GPU compilation blocked as expected")

# Restrict during specific phases
def training_phase(model_fn, data):
    # Ensure training only uses CPUs (e.g., for memory reasons)
    with chex.restrict_backends(allowed=['cpu']):
        return model_fn(data)

def inference_phase(model_fn, data):
    # Allow inference on any available device
    return model_fn(data)
```

### Dimension Mapping

```python
import chex
import jax.numpy as jnp

# Create dimension mapping for a transformer model
dims = chex.Dimensions(
    B=32,     # Batch size
    T=512,    # Sequence length
    D=768,    # Model dimension
    H=12,     # Number of heads
    V=50000,  # Vocabulary size
)

# Use dimensions for shape assertions
def transformer_layer(
    inputs,       # Shape: (B, T, D)
    weights_qkv,  # Shape: (D, 3*D)
    weights_out,  # Shape: (D, D)
):
    # Validate input shapes using dimension names. Scalar sizes come from
    # dims.size(), since indexing (e.g., dims['D']) returns a tuple.
    model_dim = dims.size('D')
    num_heads = dims.size('H')
    chex.assert_shape(inputs, dims['BTD'])
    chex.assert_shape(weights_qkv, (model_dim, 3 * model_dim))
    chex.assert_shape(weights_out, dims['DD'])

    batch_size, seq_len, _ = inputs.shape

    # Query, Key, Value projections
    qkv = jnp.dot(inputs, weights_qkv)  # (B, T, 3*D)
    qkv = qkv.reshape(batch_size, seq_len, 3, num_heads, model_dim // num_heads)
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

    # Multi-head attention computation (elided) produces attention_output...
    # Output shape should be (B, T, D)
    output = jnp.dot(attention_output, weights_out)

    chex.assert_shape(output, dims['BTD'])
    return output

# Dynamic dimension updates
def process_variable_batch(data):
    # Update the batch dimension based on actual data;
    # __setitem__ expects a shape tuple, so slice rather than index.
    dims['B'] = data.shape[:1]

    # Use updated dimensions
    chex.assert_shape(data, dims['BTD'])
    return data

# Flattened dimensions for linear layers
def create_classifier_weights():
    # Flatten sequence and model dimensions
    input_size = dims.size('TD')  # T * D = 512 * 768
    output_size = dims.size('V')  # Vocabulary size

    return jnp.ones((input_size, output_size))

# Wildcard dimensions
def flexible_attention(queries, keys, values):
    # Allow any sequence length but a fixed model dimension
    chex.assert_shape(queries, dims['B*D'])  # (B, any_seq_len, D)
    chex.assert_shape(keys, dims['B*D'])     # (B, any_seq_len, D)
    chex.assert_shape(values, dims['B*D'])   # (B, any_seq_len, D)

    # Attention computation (elided)...
    return attention_output
```

### Jittable Assertions

```python
import functools

import chex
import jax
import jax.numpy as jnp

# Enable assertions inside jitted functions
@chex.chexify  # or @chex.with_jittable_assertions
@jax.jit
def safe_division(x, y):
    # These value assertions work inside jit!
    chex.assert_tree_all_finite(x)
    chex.assert_tree_all_finite(y)

    result = x / y
    # Division by zero produces inf/nan, which this check catches.
    chex.assert_tree_all_finite(result)
    return result

# Use with async checking. chexify takes the function as its first
# positional argument, so configure it via functools.partial.
@functools.partial(chex.chexify, async_check=True)
@jax.jit
def training_step(params, batch):
    # Assertions are checked asynchronously
    chex.assert_tree_all_finite(params)
    chex.assert_shape(batch['inputs'], (32, 784))

    # Training computation (compute_loss is assumed defined elsewhere)...
    loss = compute_loss(params, batch)
    grads = jax.grad(compute_loss)(params, batch)

    chex.assert_tree_all_finite(grads)
    chex.assert_tree_all_finite(loss)

    return grads, loss

# Block until all assertions complete
for epoch in range(num_epochs):
    for batch in dataloader:
        grads, loss = training_step(params, batch)
        params = update_params(params, grads)

    # Ensure all assertions from the epoch have been checked
    chex.block_until_chexify_assertions_complete()
    print(f"Epoch {epoch} completed successfully")

# Configure error categories
@functools.partial(chex.chexify, errors=chex.ChexifyChecks.all)  # Check everything
@jax.jit
def comprehensive_checks(data):
    # Enables NaN, indexing, division, and user checks
    return jnp.mean(data)

@functools.partial(chex.chexify, errors=chex.ChexifyChecks.user | chex.ChexifyChecks.nan)
@jax.jit
def custom_checks(data):
    # Only user assertions and NaN checks
    return jnp.sum(data)
```

### Deprecation Management

```python
import functools

import chex

# Mark a function as deprecated. The helpers take the function as their
# first positional argument, so use functools.partial when decorating.
@functools.partial(chex.warn_deprecated_function, replacement='new_function_name')
def old_function(x):
    """This function is deprecated."""
    return x + 1

# Create deprecated alias
def current_function(x, y):
    return x * y

# Create a deprecated alias that warns users
old_function_name = chex.create_deprecated_function_alias(
    current_function,
    'current_function',
    'old_function_name',
)

# Transition to keyword-only arguments
@functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
def transitioning_function(required_arg, optional_arg=None, another_arg=None):
    """Function transitioning to keyword-only arguments."""
    return required_arg + (optional_arg or 0) + (another_arg or 0)

# Usage that will warn:
# transitioning_function(1, 2, 3)  # Warning: only the first arg should be positional

# Preferred usage:
# transitioning_function(1, optional_arg=2, another_arg=3)  # No warning

# Force keyword-only. The parameters stay positional for now, so
# positional calls still work but emit a warning.
@chex.warn_keyword_args_only_in_future
def keyword_only_function(arg1, arg2):
    """Function that should only accept keyword arguments."""
    return arg1 + arg2

# This will warn:
# keyword_only_function(1, 2)  # Warning about positional args

# This is correct:
# keyword_only_function(arg1=1, arg2=2)  # No warning
```

### Advanced Integration Patterns

```python
import functools

import chex
import jax
import jax.numpy as jnp

class AdvancedTrainer:
    """Training class with advanced Chex features.

    Assumes `self.model`, `self.optimizer`, `self.state`, and
    `self.log_metrics` are provided elsewhere; only the Chex
    integration is shown here.
    """

    def __init__(self, config):
        self.config = config

        # Set up dimensions
        self.dims = chex.Dimensions(
            B=config.batch_size,
            T=config.sequence_length,
            D=config.model_dim,
            C=config.num_classes,
        )

        # Configure backend restrictions
        self.allowed_backends = config.allowed_backends

    def create_training_step(self):
        """Create a jittable training step with assertions."""

        def training_step(state, batch):
            # Validate inputs
            chex.assert_tree_all_finite(state.params)
            chex.assert_shape(batch['inputs'], self.dims['BTD'])
            chex.assert_shape(batch['labels'], self.dims['BC'])

            # Forward pass
            def loss_fn(params):
                logits = self.model.apply(params, batch['inputs'])
                chex.assert_shape(logits, self.dims['BC'])
                # Softmax cross-entropy against one-hot labels.
                log_probs = jax.nn.log_softmax(logits)
                return -jnp.mean(jnp.sum(batch['labels'] * log_probs, axis=-1))

            loss, grads = jax.value_and_grad(loss_fn)(state.params)

            # Validate outputs
            chex.assert_tree_all_finite(loss)
            chex.assert_tree_all_finite(grads)

            # Update state
            new_state = self.optimizer.update(grads, state)
            return new_state, {'loss': loss}

        # chexify wraps the jitted function so the value assertions
        # above work inside jit.
        return chex.chexify(jax.jit(training_step), async_check=True)

    def train(self, train_data):
        """Training loop with backend restriction."""

        # Restrict to allowed backends during training
        with chex.restrict_backends(allowed=self.allowed_backends):
            training_step = self.create_training_step()

            for epoch in range(self.config.num_epochs):
                for step, batch in enumerate(train_data):
                    # Validate batch dimensions dynamically
                    actual_batch_size = batch['inputs'].shape[0]
                    if actual_batch_size != self.dims.size('B'):
                        # Update dimensions for the final batch
                        # (__setitem__ expects a shape tuple)
                        self.dims['B'] = (actual_batch_size,)

                    state, metrics = training_step(self.state, batch)
                    self.state = state

                    if step % 100 == 0:
                        # Ensure all async assertions have completed
                        chex.block_until_chexify_assertions_complete()
                        self.log_metrics(metrics, epoch, step)

# Integration with existing codebases
def modernize_legacy_function():
    """Example of gradually modernizing legacy code."""

    # Original function (deprecated)
    @functools.partial(chex.warn_deprecated_function, replacement='process_data_v2')
    def process_data_v1(data, normalize, scale):
        return data * scale if normalize else data

    # New function with a better API
    @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
    def process_data_v2(data, *, normalize=False, scale=1.0):
        # Add shape validation
        chex.assert_rank(data, 2)
        chex.assert_scalar_positive(scale)

        if normalize:
            data = data / jnp.linalg.norm(data, axis=1, keepdims=True)

        return data * scale

    # Future version (keyword-only)
    def process_data_v3(*, data, normalize=False, scale=1.0):
        # Enhanced with jittable assertions. `normalize` and `scale` are
        # marked static so the Python-level branch and scalar check work
        # at trace time.
        @chex.chexify
        @functools.partial(jax.jit, static_argnames=('normalize', 'scale'))
        def _process(data, normalize, scale):
            chex.assert_rank(data, 2)
            chex.assert_scalar_positive(scale)
            chex.assert_tree_all_finite(data)

            if normalize:
                norms = jnp.linalg.norm(data, axis=1, keepdims=True)
                chex.assert_tree_all_finite(norms)
                data = data / norms

            result = data * scale
            chex.assert_tree_all_finite(result)
            return result

        return _process(data, normalize=normalize, scale=scale)
```

## Key Features

### Fine-Grained Control
- Precise backend restrictions for different computation phases
- Flexible dimension management with arithmetic operations
- Configurable assertion checking with multiple error categories

### Production Ready
- Async assertion checking for minimal performance impact
- Deprecation management for smooth API transitions
- Integration with the existing JAX transformation pipeline

### Developer Friendly
- Clear error messages and warnings
- Readable dimension specifications
- Comprehensive debugging support

## Best Practices

### Use Backend Restrictions Strategically
```python
# Good: Restrict during specific phases
with chex.restrict_backends(allowed=['cpu']):
    # Memory-intensive preprocessing
    pass

# Avoid: Overly broad restrictions
with chex.restrict_backends(forbidden=['gpu']):
    # Entire training loop - might be unnecessarily restrictive
    pass
```

### Design Maintainable Dimension Systems
```python
# Good: Centralized dimension management
dims = chex.Dimensions(B=32, T=100, D=512)

# Good: Clear dimension naming. Lookup strings are parsed one character
# per dimension, so keep names to single letters and document them.
dims = chex.Dimensions(
    B=32,   # batch size
    T=100,  # sequence length
    D=512,  # embedding dim
)
```

### Plan Deprecation Carefully
```python
# Good: Provide a clear migration path
@functools.partial(chex.warn_deprecated_function, replacement='new_api_function')
def old_function():
    pass

# Good: Gradual transition
@functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
def transitioning_function(required, *, optional=None):
    pass
```
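
### Surface Async Assertion Failures Deliberately
When using `async_check=True`, flush pending checks at controlled synchronization points so failures are attributed to the right step (a sketch; `chexified_training_step` is a placeholder for any chexified function):
```python
# Good: Flush pending async checks at controlled sync points
metrics = chexified_training_step(params, batch)
chex.block_until_chexify_assertions_complete()
```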