# Debugging and Development Utilities

Development tools for patching JAX functions with fake implementations, enabling easier debugging, testing in CPU-only environments, and controlled development workflows.

## Capabilities

### Fake JAX Transformations

Functions to replace JAX transformations with simpler implementations for debugging.

```python { .api }
def fake_jit(enable_patching=True):
    """
    Context manager that patches jax.jit to be an identity transformation.

    While active, jax.jit returns functions uncompiled, enabling:
    - Step-through debugging with standard Python debuggers
    - Faster iteration during development
    - Access to intermediate values and Python control flow

    Parameters:
    - enable_patching: If False, patching is disabled and the real
      jax.jit is used (convenient for toggling debug modes)

    Returns:
    - Context object that patches jax.jit on entry and restores it on exit
    """

def fake_pmap(enable_patching=True, jit_result=False):
    """
    Context manager that patches jax.pmap to use vmap on a single device.

    Enables testing of pmap code on machines without multiple devices
    by replacing parallel mapping with vectorized mapping.

    Parameters:
    - enable_patching: If False, patching is disabled and the real
      jax.pmap is used
    - jit_result: If True, the resulting vmapped function is also jitted

    Returns:
    - Context object that patches jax.pmap on entry and restores it on exit
    """

def fake_pmap_and_jit(enable_pmap_patching=True, enable_jit_patching=True):
    """
    Context manager that patches both jax.pmap and jax.jit.

    Combines fake_pmap and fake_jit behavior for comprehensive debugging
    of functions that use both transformations.

    Parameters:
    - enable_pmap_patching: Whether to patch jax.pmap (as in fake_pmap)
    - enable_jit_patching: Whether to patch jax.jit (as in fake_jit)

    Returns:
    - Context object that applies both patches
    """
```
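
The fakes are used as context managers throughout the examples below; they also expose explicit `start()` and `stop()` methods, which is convenient in test `setUp`/`tearDown` hooks. A minimal sketch:

```python
import chex

# Patch jax.pmap without a with-block; between start() and stop(),
# jax.pmap behaves like jax.vmap.
fake_pmap = chex.fake_pmap()
fake_pmap.start()
# ... code under test ...
fake_pmap.stop()
```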

### Device Configuration

Functions for controlling device behavior in testing environments.

```python { .api }
def set_n_cpu_devices(n=None):
    """
    Force XLA to use n CPU threads as host devices.

    Enables testing of multi-device code (like pmap) on single-CPU machines
    by creating multiple virtual CPU devices.

    IMPORTANT: Must be called before any JAX operations or device queries.

    Parameters:
    - n: Number of CPU devices to create (uses FLAGS.chex_n_cpu_devices if None)

    Raises:
    - RuntimeError: If XLA backends are already initialized
    """

def get_n_cpu_devices_from_xla_flags():
    """
    Parse the number of CPU devices from the XLA_FLAGS environment variable.

    Returns:
    - Number of CPU devices configured in XLA_FLAGS (default: 1)
    """
```
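
As a quick orientation to the call order (a sketch; the `jax.device_count()` value assumes a CPU-only host):

```python
import chex

# Configure virtual devices before JAX initializes its backends
chex.set_n_cpu_devices(2)

import jax
print(jax.device_count())                       # 2 on a CPU-only host
print(chex.get_n_cpu_devices_from_xla_flags())  # 2, parsed from XLA_FLAGS
```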

## Usage Examples

### Basic Debugging Setup

```python
import chex
import jax
import jax.numpy as jnp

# Example parameters and data so the snippet runs standalone
params = {'weights': jnp.ones((5, 1)), 'bias': jnp.zeros((1,))}
data = jnp.ones((8, 5))
labels = jnp.zeros((8, 1))

# Original function with jit
@jax.jit
def compute_loss(params, data, labels):
    predictions = jnp.dot(data, params['weights']) + params['bias']
    return jnp.mean((predictions - labels) ** 2)

# For debugging, use the fake_jit context manager
with chex.fake_jit():
    # Now jax.jit calls become identity functions
    @jax.jit  # This becomes a no-op
    def compute_loss_debug(params, data, labels):
        predictions = jnp.dot(data, params['weights']) + params['bias']
        # Can now set breakpoints and inspect intermediate values
        print(f"Predictions shape: {predictions.shape}")
        loss = jnp.mean((predictions - labels) ** 2)
        print(f"Loss value: {loss}")
        return loss

    # Function executes without compilation
    result = compute_loss_debug(params, data, labels)
```

### Testing Multi-Device Code

```python
# Setup multiple CPU devices for testing
chex.set_n_cpu_devices(4)  # Must be called before any JAX operations

def parallel_computation(data):
    """Function designed to run on multiple devices."""
    return jnp.sum(data, axis=-1)

# Test with fake_pmap
with chex.fake_pmap():
    # pmap becomes vmap, works on single physical device
    parallel_fn = jax.pmap(parallel_computation)

    # Create data for 4 "devices"
    batch_data = jnp.ones((4, 10, 5))  # (devices, batch, features)
    result = parallel_fn(batch_data)

    print(f"Result shape: {result.shape}")  # (4, 10)
```

### Comprehensive Debugging Context

```python
def debug_training_step(state, batch):
    """Training step with comprehensive debugging."""

    def loss_fn(params):
        logits = apply_model(params, batch['inputs'])
        # Softmax cross-entropy via jax.nn.log_softmax
        log_probs = jax.nn.log_softmax(logits)
        return -jnp.mean(jnp.sum(batch['labels'] * log_probs, axis=-1))

    # Compute loss and gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Update parameters
    new_params = update_params(state.params, grads, state.optimizer)

    return state._replace(params=new_params), loss

# Use fake transformations for debugging
with chex.fake_pmap_and_jit():
    # Both pmap and jit are disabled
    @jax.pmap  # Becomes vmap
    @jax.jit   # Becomes identity
    def debug_step(state, batch):
        return debug_training_step(state, batch)

    # Can step through with debugger
    new_state, loss = debug_step(training_state, data_batch)
```

### Conditional Debugging

```python
import os
from contextlib import nullcontext

DEBUG_MODE = os.getenv('DEBUG_JAX', '0') == '1'

def create_training_function():
    if DEBUG_MODE:
        # Development mode: disable transformations
        context = chex.fake_pmap_and_jit()
    else:
        # Production mode: use real transformations
        context = nullcontext()

    with context:
        @jax.pmap
        @jax.jit
        def train_step(state, batch):
            # Training logic here
            return updated_state, metrics

    return train_step

# Usage
train_fn = create_training_function()
# Automatically uses fake or real transformations based on DEBUG_MODE
```

### Device Setup for Testing

```python
def setup_test_environment():
    """Set up a consistent test environment across different machines."""
    try:
        # Try to set up multiple CPU devices for pmap testing
        chex.set_n_cpu_devices(8)
        print("Multi-device testing enabled")
        return True
    except RuntimeError as e:
        # Raised when JAX backends are already initialized
        print(f"Single-device testing only: {e}")
        return False

def test_parallel_algorithm():
    multi_device = setup_test_environment()

    def algorithm(data):
        return jnp.mean(data ** 2)

    if multi_device:
        # Test with real pmap
        parallel_fn = jax.pmap(algorithm)
        test_data = jnp.ones((8, 100))  # 8 devices, 100 features each
    else:
        # Test with fake pmap (becomes vmap)
        with chex.fake_pmap():
            parallel_fn = jax.pmap(algorithm)
            test_data = jnp.ones((2, 100))  # Fewer "devices"

    result = parallel_fn(test_data)
    assert result.shape[0] == test_data.shape[0]
```

### Advanced Debugging Patterns

```python
from contextlib import nullcontext

class DebuggableModel:
    """Model class with built-in debugging support."""

    def __init__(self, debug=False):
        self.debug = debug

    def _debug_context(self):
        # Create a fresh context per use; a single context object
        # must not be entered more than once
        return chex.fake_jit() if self.debug else nullcontext()

    def __enter__(self):
        self._active_context = self._debug_context()
        self._active_context.__enter__()
        return self

    def __exit__(self, *args):
        self._active_context.__exit__(*args)

    def forward(self, params, inputs):
        with self._debug_context():
            @jax.jit
            def _forward(params, inputs):
                # Model computation
                hidden = jnp.dot(inputs, params['W1']) + params['b1']
                if self.debug:
                    print(f"Hidden layer stats: mean={jnp.mean(hidden):.3f}")

                hidden = jax.nn.relu(hidden)
                output = jnp.dot(hidden, params['W2']) + params['b2']

                if self.debug:
                    print(f"Output layer stats: mean={jnp.mean(output):.3f}")

                return output

            return _forward(params, inputs)

# Usage
with DebuggableModel(debug=True) as model:
    predictions = model.forward(params, data)
    # Prints intermediate statistics when debug=True
```

### Testing Framework Integration

```python
import unittest

import chex
import jax
import jax.numpy as jnp

class TestWithDebugging(unittest.TestCase):

    def setUp(self):
        # Set up CPU devices for consistent testing; raises RuntimeError
        # once JAX backends are already initialized
        try:
            chex.set_n_cpu_devices(4)
            self.multi_device = True
        except RuntimeError:
            self.multi_device = False

    def test_jitted_function(self):
        """Test function behavior with and without jit."""

        def compute_fn(x):
            return x ** 2 + 2 * x + 1

        x = jnp.array([1.0, 2.0, 3.0])

        # Test without jit (easier debugging)
        with chex.fake_jit():
            jitted_fn = jax.jit(compute_fn)
            result_fake = jitted_fn(x)

        # Test with real jit
        real_jitted_fn = jax.jit(compute_fn)
        result_real = real_jitted_fn(x)

        # Results should be identical
        chex.assert_trees_all_close(result_fake, result_real)

    def test_pmap_function(self):
        """Test pmap function with fake implementation."""

        def parallel_sum(x):
            return jnp.sum(x)

        if self.multi_device:
            # Test with real pmap
            pmapped_fn = jax.pmap(parallel_sum)
            test_data = jnp.ones((4, 10))
            result = pmapped_fn(test_data)
            expected_shape = (4,)
        else:
            # Test with fake pmap
            with chex.fake_pmap():
                pmapped_fn = jax.pmap(parallel_sum)
                test_data = jnp.ones((2, 10))
                result = pmapped_fn(test_data)
                expected_shape = (2,)

        self.assertEqual(result.shape, expected_shape)
```

## Key Features

### Non-Intrusive Debugging

- Use context managers to temporarily disable transformations
- Original code remains unchanged
- Easy to toggle between debug and production modes (see the sketch below)
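
A minimal sketch of the toggle, using the `enable_patching` argument so one code path serves both modes (`run_model` is a hypothetical helper):

```python
import chex
import jax
import jax.numpy as jnp

def run_model(fn, x, debug=False):
    # enable_patching=False makes the context a no-op, so the real jit runs
    with chex.fake_jit(enable_patching=debug):
        return jax.jit(fn)(x)

run_model(lambda x: x * 2, jnp.ones(3), debug=True)   # uncompiled, debuggable
run_model(lambda x: x * 2, jnp.ones(3), debug=False)  # real jit
```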

### Multi-Device Testing

- Test pmap code on single-device machines
- Consistent behavior across different hardware configurations
- Simplified development workflow

### Step-Through Debugging

- Set breakpoints in jitted functions
- Inspect intermediate values
- Use standard Python debugging tools (see the sketch below)
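
For instance, under `fake_jit` a "jitted" function body runs as plain Python over concrete arrays, so the standard `breakpoint()` works (a sketch; the breakpoint is commented out to keep the snippet non-interactive):

```python
import chex
import jax
import jax.numpy as jnp

with chex.fake_jit():
    @jax.jit  # a no-op inside this context
    def f(x):
        y = x * 2
        # breakpoint()  # uncomment to drop into pdb with concrete values
        return y + 1

    print(f(jnp.arange(3)))  # [1 3 5]
```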

### Development Speed

- Faster iteration during development
- Skip compilation during debugging
- Quick testing of algorithmic changes

## Best Practices

### Use Context Managers

```python
# Good: Use context managers for temporary debugging
with chex.fake_jit():
    result = my_jitted_function(data)

# Avoid: Global patching that affects other code
```

### Set Up Devices Early

```python
# Good: Set up devices before any JAX computation or device query
import chex
chex.set_n_cpu_devices(4)

import jax  # JAX backends initialize on first use, after device setup

# Avoid: Calling set_n_cpu_devices after JAX backends are initialized
```

### Combine with Testing

```python
# Good: Use debugging utilities in tests
class MyTest(chex.TestCase):

    def test_with_debugging(self):
        with chex.fake_jit():
            # Test logic here
            pass
```

### Document Debug Modes

```python
from contextlib import nullcontext

def my_function(data, debug=False):
    """Process data with optional debugging.

    Args:
        data: Input data
        debug: If True, disables jit for easier debugging
    """
    context = chex.fake_jit() if debug else nullcontext()
    with context:
        # Function implementation
        pass
```