# Debugging and Development Utilities

Development tools for patching JAX functions with fake implementations, enabling easier debugging, testing in CPU-only environments, and controlled development workflows.

## Capabilities

### Fake JAX Transformations

Context managers that replace JAX transformations with simpler implementations for debugging. Each helper returns a context object that can be used in a `with` statement or started and stopped manually.

```python { .api }
def fake_jit(enable_patching=True):
    """
    Replace jax.jit with a no-op for debugging.

    Within the returned context, jax.jit returns the original function
    without compilation, enabling:
    - Step-through debugging with standard Python debuggers
    - Faster iteration during development
    - Access to intermediate values and Python control flow

    Parameters:
    - enable_patching: If False, the context is a no-op and real jit is used

    Returns:
    - Context manager that patches jax.jit
    """

def fake_pmap(enable_patching=True, **kwargs):
    """
    Replace jax.pmap with jax.vmap for debugging on a single device.

    Enables testing of pmap code on machines without multiple devices
    by replacing parallel mapping with vectorized mapping. Most
    pmap-specific arguments are ignored by the fake implementation.

    Parameters:
    - enable_patching: If False, the context is a no-op and real pmap is used
    - **kwargs: Additional options controlling the fake behavior

    Returns:
    - Context manager that patches jax.pmap
    """

def fake_pmap_and_jit(enable_patching=True, **kwargs):
    """
    Replace both jax.pmap and jax.jit with simpler implementations.

    Combines fake_pmap and fake_jit behavior for comprehensive debugging
    of functions that use both transformations.

    Parameters:
    - enable_patching: If False, the context is a no-op
    - **kwargs: Additional options forwarded to the fake implementations

    Returns:
    - Context manager that patches both jax.pmap and jax.jit
    """
```
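
The patching these helpers perform can be illustrated with a small standard-library sketch that needs no JAX at all. `toy_jax`, `real_jit`, and the toy `fake_jit` below are hypothetical stand-ins for illustration; this is not chex's implementation:

```python
import contextlib
from types import SimpleNamespace

def real_jit(fn, **kwargs):
    """Stand-in for a compiling jit: returns a distinct wrapper object."""
    def compiled(*args, **kw):
        return fn(*args, **kw)
    return compiled

# Hypothetical jax-like module, so this sketch runs without JAX installed.
toy_jax = SimpleNamespace(jit=real_jit)

@contextlib.contextmanager
def fake_jit():
    """Temporarily make `jit` an identity function, then restore it."""
    original = toy_jax.jit
    toy_jax.jit = lambda fn, **kwargs: fn  # ignore jit options, skip "compilation"
    try:
        yield
    finally:
        toy_jax.jit = original  # restore even if the body raises

def f(x):
    return x + 1

with fake_jit():
    assert toy_jax.jit(f) is f   # patched: no wrapping at all
assert toy_jax.jit(f) is not f   # restored: real wrapping again
```

Because the patch is installed and removed inside a `with` block, code outside the block always sees the real transformation, which is what makes context-manager patching non-intrusive.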
### Device Configuration

Functions for controlling device behavior in testing environments.

```python { .api }
def set_n_cpu_devices(n=None):
    """
    Force XLA to use n CPU threads as host devices.

    Enables testing of multi-device code (like pmap) on single-CPU machines
    by creating multiple virtual CPU devices.

    IMPORTANT: Must be called before any JAX operations or device queries.

    Parameters:
    - n: Number of CPU devices to create (uses FLAGS.chex_n_cpu_devices if None)

    Raises:
    - RuntimeError: If XLA backends are already initialized
    """

def get_n_cpu_devices_from_xla_flags():
    """
    Parse the number of CPU devices from XLA environment flags.

    Returns:
    - Number of CPU devices configured in XLA_FLAGS (default: 1)
    """
```
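
Under the hood, virtual host devices are requested through the `XLA_FLAGS` environment variable, via the `--xla_force_host_platform_device_count` flag. A minimal sketch of the kind of parsing involved (`parse_host_device_count` is an illustrative helper, not chex's actual code):

```python
import os
import re

def parse_host_device_count(xla_flags: str) -> int:
    """Return the host device count requested in an XLA_FLAGS string (default 1)."""
    match = re.search(r"--xla_force_host_platform_device_count=(\d+)", xla_flags)
    return int(match.group(1)) if match else 1

# Requesting 8 virtual CPU devices, roughly what set_n_cpu_devices(8) arranges:
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)
print(parse_host_device_count(os.environ["XLA_FLAGS"]))  # 8
```

This also explains the ordering constraint: the environment variable must be set before XLA backends initialize, because it is read exactly once at backend startup.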
## Usage Examples

### Basic Debugging Setup

```python
import chex
import jax
import jax.numpy as jnp

# Example inputs
params = {'weights': jnp.ones((5, 1)), 'bias': jnp.zeros((1,))}
data = jnp.ones((8, 5))
labels = jnp.zeros((8, 1))

# Original function with jit
@jax.jit
def compute_loss(params, data, labels):
    predictions = jnp.dot(data, params['weights']) + params['bias']
    return jnp.mean((predictions - labels) ** 2)

# For debugging, use the fake_jit context manager
with chex.fake_jit():
    # Now jax.jit calls become identity functions
    @jax.jit  # This becomes a no-op
    def compute_loss_debug(params, data, labels):
        predictions = jnp.dot(data, params['weights']) + params['bias']
        # Can now set breakpoints and inspect intermediate values
        print(f"Predictions shape: {predictions.shape}")
        loss = jnp.mean((predictions - labels) ** 2)
        print(f"Loss value: {loss}")
        return loss

    # Function executes without compilation
    result = compute_loss_debug(params, data, labels)
```
### Testing Multi-Device Code

```python
# Set up multiple CPU devices for testing
chex.set_n_cpu_devices(4)  # Must be called before any JAX operations

def parallel_computation(data):
    """Function designed to run on multiple devices."""
    return jnp.sum(data, axis=-1)

# Test with fake_pmap
with chex.fake_pmap():
    # pmap becomes vmap, works on a single physical device
    parallel_fn = jax.pmap(parallel_computation)

    # Create data for 4 "devices"
    batch_data = jnp.ones((4, 10, 5))  # (devices, batch, features)
    result = parallel_fn(batch_data)

    print(f"Result shape: {result.shape}")  # (4, 10)
```
### Comprehensive Debugging Context

```python
import optax  # jax.nn has no softmax_cross_entropy_with_logits; optax provides the loss

def debug_training_step(state, batch):
    """Training step with comprehensive debugging."""

    def loss_fn(params):
        logits = apply_model(params, batch['inputs'])
        return jnp.mean(optax.softmax_cross_entropy(
            logits=logits, labels=batch['labels']
        ))

    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Update parameters
    new_params = update_params(state.params, grads, state.optimizer)

    return state._replace(params=new_params), loss

# Use fake transformations for debugging
with chex.fake_pmap_and_jit():
    # Both pmap and jit are disabled
    @jax.pmap  # Becomes vmap
    @jax.jit   # Becomes a no-op
    def debug_step(state, batch):
        return debug_training_step(state, batch)

    # Can step through with a debugger
    new_state, loss = debug_step(training_state, data_batch)
```
### Conditional Debugging

```python
import os
from contextlib import nullcontext

DEBUG_MODE = os.getenv('DEBUG_JAX', '0') == '1'

def create_training_function():
    if DEBUG_MODE:
        # Development mode: disable transformations
        context = chex.fake_pmap_and_jit()
    else:
        # Production mode: use real transformations
        context = nullcontext()

    with context:
        @jax.pmap
        @jax.jit
        def train_step(state, batch):
            # Training logic here
            return updated_state, metrics

    return train_step

# Usage
train_fn = create_training_function()
# Automatically uses fake or real transformations based on DEBUG_MODE
```
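
The same toggle pattern works with any pair of context managers. Here is a self-contained standard-library version of the idea; `toy_fake_transformations`, `state`, and `make_debug_context` are illustrative names, not part of chex:

```python
import os
from contextlib import contextmanager, nullcontext

state = {"patched": False}  # records whether fake transformations are active

@contextmanager
def toy_fake_transformations():
    """Illustrative stand-in for a patching context like chex.fake_pmap_and_jit()."""
    state["patched"] = True
    try:
        yield
    finally:
        state["patched"] = False

def make_debug_context():
    # Same environment-variable toggle as DEBUG_MODE above
    if os.getenv("DEBUG_JAX", "0") == "1":
        return toy_fake_transformations()
    return nullcontext()

os.environ["DEBUG_JAX"] = "1"
with make_debug_context():
    print(state["patched"])  # True
print(state["patched"])      # False
```

`nullcontext()` is the key piece: it lets production and debug code share one `with` statement, so no branching is needed at the call site.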
### Device Setup for Testing

```python
def setup_test_environment():
    """Set up a consistent test environment across different machines."""
    try:
        # Try to set up multiple CPU devices for pmap testing
        chex.set_n_cpu_devices(8)
        print("Multi-device testing enabled")
        return True
    except RuntimeError as e:
        print(f"Single-device testing only: {e}")
        return False

def test_parallel_algorithm():
    multi_device = setup_test_environment()

    def algorithm(data):
        return jnp.mean(data ** 2)

    if multi_device:
        # Test with real pmap
        parallel_fn = jax.pmap(algorithm)
        test_data = jnp.ones((8, 100))  # 8 devices, 100 features each
    else:
        # Test with fake pmap (becomes vmap)
        with chex.fake_pmap():
            parallel_fn = jax.pmap(algorithm)
        test_data = jnp.ones((2, 100))  # Fewer "devices"

    result = parallel_fn(test_data)
    assert result.shape[0] == test_data.shape[0]
```
### Advanced Debugging Patterns

```python
from contextlib import nullcontext

class DebuggableModel:
    """Model class with built-in debugging support."""

    def __init__(self, debug=False):
        self.debug = debug

    def _debug_context(self):
        # Create a fresh context per call; context objects should not be re-entered
        return chex.fake_jit() if self.debug else nullcontext()

    def forward(self, params, inputs):
        with self._debug_context():
            @jax.jit  # A no-op when debug=True
            def _forward(params, inputs):
                # Model computation
                hidden = jnp.dot(inputs, params['W1']) + params['b1']
                if self.debug:
                    print(f"Hidden layer stats: mean={jnp.mean(hidden):.3f}")

                hidden = jax.nn.relu(hidden)
                output = jnp.dot(hidden, params['W2']) + params['b2']

                if self.debug:
                    print(f"Output layer stats: mean={jnp.mean(output):.3f}")

                return output

            return _forward(params, inputs)

# Usage
model = DebuggableModel(debug=True)
predictions = model.forward(params, data)
# Prints intermediate statistics when debug=True
```
### Testing Framework Integration

```python
import unittest

import chex
import jax
import jax.numpy as jnp

class TestWithDebugging(unittest.TestCase):

    def setUp(self):
        # Set up CPU devices for consistent testing
        try:
            chex.set_n_cpu_devices(4)
            self.multi_device = True
        except RuntimeError:
            self.multi_device = False

    def test_jitted_function(self):
        """Test function behavior with and without jit."""

        def compute_fn(x):
            return x ** 2 + 2 * x + 1

        x = jnp.array([1.0, 2.0, 3.0])

        # Test without jit (easier debugging)
        with chex.fake_jit():
            jitted_fn = jax.jit(compute_fn)
            result_fake = jitted_fn(x)

        # Test with real jit
        real_jitted_fn = jax.jit(compute_fn)
        result_real = real_jitted_fn(x)

        # Results should be identical
        chex.assert_trees_all_close(result_fake, result_real)

    def test_pmap_function(self):
        """Test a pmap function with the fake implementation."""

        def parallel_sum(x):
            return jnp.sum(x)

        if self.multi_device:
            # Test with real pmap
            pmapped_fn = jax.pmap(parallel_sum)
            test_data = jnp.ones((4, 10))
            result = pmapped_fn(test_data)
            expected_shape = (4,)
        else:
            # Test with fake pmap
            with chex.fake_pmap():
                pmapped_fn = jax.pmap(parallel_sum)
                test_data = jnp.ones((2, 10))
                result = pmapped_fn(test_data)
            expected_shape = (2,)

        self.assertEqual(result.shape, expected_shape)
```
## Key Features

### Non-Intrusive Debugging
- Use context managers to temporarily disable transformations
- Original code remains unchanged
- Easy to toggle between debug and production modes

### Multi-Device Testing
- Test pmap code on single-device machines
- Consistent behavior across different hardware configurations
- Simplified development workflow

### Step-Through Debugging
- Set breakpoints in jitted functions
- Inspect intermediate values
- Use standard Python debugging tools

### Performance Development
- Faster iteration during development
- Skip compilation during debugging
- Quick testing of algorithmic changes
## Best Practices

### Use Context Managers
```python
# Good: Use context managers for temporary debugging
with chex.fake_jit():
    result = my_jitted_function(data)

# Avoid: Global patching that affects other code
```

### Set Up Devices Early
```python
# Good: Set up devices before any JAX computations or device queries
chex.set_n_cpu_devices(4)
# (Importing jax is fine; what matters is that no XLA backend has been
# initialized yet when set_n_cpu_devices is called.)

# Avoid: Setting devices after XLA backends are initialized
```

### Combine with Testing
```python
# Good: Use debugging utilities in tests
class MyTest(chex.TestCase):
    def test_with_debugging(self):
        with chex.fake_jit():
            # Test logic here
            pass
```

### Document Debug Modes
```python
from contextlib import nullcontext

def my_function(data, debug=False):
    """Process data with optional debugging.

    Args:
        data: Input data
        debug: If True, disables jit for easier debugging
    """
    context = chex.fake_jit() if debug else nullcontext()
    with context:
        # Function implementation
        pass
```