
# Advanced Features

Specialized utilities including backend restriction, dimension mapping, jittable assertions, and deprecation management for advanced JAX development scenarios.

## Capabilities

### Backend Restriction

Context manager for controlling JAX backend compilation and device usage.

```python { .api }
def restrict_backends(*, allowed=None, forbidden=None):
    """
    Context manager that prevents JAX compilation for specified backends.

    Useful for ensuring code runs only on intended devices or for catching
    accidental compilation on restricted hardware.

    Parameters:
    - allowed: Sequence of allowed backend platform names (e.g., ['cpu', 'gpu'])
    - forbidden: Sequence of forbidden backend platform names

    Yields:
    - Context where compilation for forbidden platforms raises RestrictedBackendError

    Raises:
    - ValueError: If neither allowed nor forbidden is specified, or if the two conflict
    - RestrictedBackendError: If compilation is attempted on a restricted backend
    """

class RestrictedBackendError(RuntimeError):
    """Exception raised when compilation is attempted on a restricted backend."""
```
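
Per the `Raises` contract above, the configuration is validated eagerly. A minimal sketch of that failure mode, assuming only the documented behavior:

```python
import chex

# Neither `allowed` nor `forbidden` is given: per the documented contract,
# this raises ValueError before any compilation happens.
try:
    with chex.restrict_backends():
        pass
except ValueError as err:
    print(f"rejected: {err}")
```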

### Dimension Mapping

Utility class for managing named dimensions and shape specifications.

```python { .api }
class Dimensions:
    """
    Lightweight utility that maps strings to shape tuples.

    Enables readable shape specifications using named dimensions
    and supports dimension arithmetic and wildcard dimensions.

    Examples:
    >>> dims = chex.Dimensions(B=3, T=5, N=7)
    >>> dims['NBT']    # (7, 3, 5)
    >>> dims['(BT)N']  # (15, 7) - flattened dimensions
    >>> dims['BT*']    # (3, 5, None) - wildcard dimension
    """

    def __init__(self, **kwargs):
        """
        Initialize dimensions with named size mappings.

        Parameters:
        - **kwargs: Dimension name to size mappings (e.g., B=32, T=100)
        """

    def __getitem__(self, key):
        """
        Get shape tuple for dimension string specification.

        Parameters:
        - key: String specifying dimensions (e.g., 'BTC', '(BT)C', 'BT*')

        Returns:
        - Tuple of integers and/or None for wildcard dimensions
        """

    def __setitem__(self, key, value):
        """
        Set dimension sizes from shape tuple.

        Parameters:
        - key: String specifying dimensions
        - value: Shape tuple to assign to dimensions
        """

    def size(self, key):
        """
        Get total size (product) of specified dimensions.

        Parameters:
        - key: String specifying dimensions

        Returns:
        - Total number of elements in the specified shape
        """
```
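
A minimal sketch tying these methods together, assuming only the signatures above (array and names here are illustrative):

```python
import numpy as np
import chex

dims = chex.Dimensions(B=4)        # __init__: seed the known sizes
x = np.zeros((4, 10, 8))

dims['BTD'] = x.shape              # __setitem__: bind T and D from a real array
assert dims['BTD'] == (4, 10, 8)   # __getitem__: readable shape lookups
assert dims.size('TD') == 80       # size: product 10 * 8
```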

### Jittable Assertions

Advanced assertion system that works inside jitted functions using JAX checkify.

```python { .api }
def chexify(
    fn,
    async_check=True,
    errors=ChexifyChecks.user
):
    """
    Enable Chex value assertions inside jitted functions.

    Wraps function to enable runtime assertions that work with JAX transformations
    by using JAX's checkify system for delayed error checking.

    Parameters:
    - fn: Function to wrap with jittable assertions
    - async_check: Whether to check errors asynchronously
    - errors: Set of error categories to check (from ChexifyChecks)

    Returns:
    - Wrapped function that supports Chex assertions inside jit
    """

def with_jittable_assertions(fn):
    """
    Decorator for enabling jittable assertions in a function.

    Equivalent to chexify(fn) but as a decorator.

    Parameters:
    - fn: Function to decorate

    Returns:
    - Function with jittable assertions enabled
    """

def block_until_chexify_assertions_complete():
    """
    Wait for all asynchronous assertion checks to complete.

    Should be called after computations that use chexify to ensure
    all assertion errors are properly surfaced.
    """

class ChexifyChecks:
    """
    Collection of checkify error categories for jittable assertions.

    Attributes:
    - user: User-defined checks (Chex assertions)
    - nan: NaN detection checks
    - index: Array indexing checks
    - div: Division-by-zero checks
    - float: Floating-point error checks
    - automatic: Automatically enabled checks
    - all: All available checks
    """
```
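
Because `fn` is a required positional parameter in the signature above, options are bound with `functools.partial` rather than by calling `chexify` as a decorator factory. A minimal sketch of the wiring (the function name is illustrative):

```python
import functools
import chex
import jax
import jax.numpy as jnp

# chexify wraps the *jitted* function; async_check is bound via partial.
@functools.partial(chex.chexify, async_check=False)
@jax.jit
def checked_log(x):
    chex.assert_tree_all_finite(x)  # value assertion, runs under jit
    return jnp.log(x)

print(checked_log(jnp.array([1.0, 2.0])))
```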

### Deprecation Management

Utilities for managing deprecated functions and warning users about API changes.

```python { .api }
def warn_deprecated_function(fun, replacement=None):
    """
    Decorator to mark a function as deprecated.

    Emits DeprecationWarning when the decorated function is called.

    Parameters:
    - fun: Function to mark as deprecated
    - replacement: Optional name of replacement function

    Returns:
    - Wrapped function that emits deprecation warning
    """

def create_deprecated_function_alias(fun, new_name, deprecated_alias):
    """
    Create a deprecated alias for a function.

    Creates a new function that emits a deprecation warning and delegates
    to the original function.

    Parameters:
    - fun: Original function
    - new_name: Current name of the function
    - deprecated_alias: Deprecated alias name

    Returns:
    - Deprecated alias function
    """

def warn_only_n_pos_args_in_future(fun, n):
    """
    Warn if more than n positional arguments are passed.

    Helps transition functions to keyword-only arguments by warning
    when too many positional arguments are used.

    Parameters:
    - fun: Function to wrap
    - n: Maximum number of allowed positional arguments

    Returns:
    - Wrapped function that warns about excess positional arguments
    """

def warn_keyword_args_only_in_future(fun):
    """
    Warn if any positional arguments are passed (keyword-only transition).

    Equivalent to warn_only_n_pos_args_in_future(fun, 0).

    Parameters:
    - fun: Function to wrap

    Returns:
    - Wrapped function that warns about positional arguments
    """
```
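
A minimal sketch of the alias behavior described above, assuming the alias warns with `DeprecationWarning` (as `warn_deprecated_function` does); the function names are illustrative:

```python
import warnings
import chex

def current_fn(x):
    return x * 2

old_fn = chex.create_deprecated_function_alias(
    current_fn, 'current_fn', 'old_fn')

# The alias delegates to the original and emits a deprecation warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')  # DeprecationWarning is ignored by default
    assert old_fn(3) == 6
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```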

## Usage Examples

### Backend Restriction

```python
import functools

import chex
import jax
import jax.numpy as jnp

# Ensure computation only runs on CPU
with chex.restrict_backends(allowed=['cpu']):
    @jax.jit
    def cpu_only_computation(x):
        return x ** 2

    result = cpu_only_computation(jnp.array([1, 2, 3]))
    # Works fine - compiles for CPU

# Prevent accidental GPU usage
with chex.restrict_backends(forbidden=['gpu', 'tpu']):
    try:
        # Attempt GPU compilation; options to jax.jit are bound via
        # functools.partial (jax.devices('gpu') itself raises if no GPU
        # is attached).
        @functools.partial(jax.jit, device=jax.devices('gpu')[0])
        def gpu_computation(x):
            return x + 1

        gpu_computation(jnp.array([1]))
    except chex.RestrictedBackendError:
        print("GPU compilation blocked as expected")

# Restrict during specific phases
def training_phase(model_fn, data):
    # Ensure training only uses CPUs (e.g., for memory reasons)
    with chex.restrict_backends(allowed=['cpu']):
        return model_fn(data)

def inference_phase(model_fn, data):
    # Allow inference on any available device
    return model_fn(data)
```

### Dimension Mapping

```python
import chex
import jax.numpy as jnp

# Create dimension mapping for a transformer model
dims = chex.Dimensions(
    B=32,     # Batch size
    T=512,    # Sequence length
    D=768,    # Model dimension
    H=12,     # Number of heads
    V=50000,  # Vocabulary size
)

# Use dimensions for shape assertions
def transformer_layer(
    inputs,       # Shape: (B, T, D)
    weights_qkv,  # Shape: (D, 3*D)
    weights_out,  # Shape: (D, D)
):
    # Validate input shapes using dimension names
    chex.assert_shape(inputs, dims['BTD'])
    (model_dim,) = dims['D']   # unpack single sizes via __getitem__
    (num_heads,) = dims['H']
    chex.assert_shape(weights_qkv, (model_dim, 3 * model_dim))
    chex.assert_shape(weights_out, dims['DD'])

    batch_size, seq_len, _ = inputs.shape

    # Query, key, value projections
    qkv = jnp.dot(inputs, weights_qkv)  # (B, T, 3*D)
    qkv = qkv.reshape(batch_size, seq_len, 3, num_heads, model_dim // num_heads)
    q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

    # Multi-head attention computation (elided)...
    # Output shape should be (B, T, D)
    output = jnp.dot(attention_output, weights_out)

    chex.assert_shape(output, dims['BTD'])
    return output

# Dynamic dimension updates
def process_variable_batch(data):
    # Update the batch dimension from the actual data
    # (__setitem__ takes a shape tuple)
    dims['B'] = (data.shape[0],)

    # Use the updated dimensions
    chex.assert_shape(data, dims['BTD'])
    return data

# Flattened dimensions for linear layers
def create_classifier_weights():
    # Flatten sequence and model dimensions
    input_size = dims.size('TD')   # T * D = 512 * 768
    output_size = dims.size('V')   # Vocabulary size

    return jnp.ones((input_size, output_size))

# Wildcard dimensions
def flexible_attention(queries, keys, values):
    # Allow any sequence length but a fixed model dimension
    chex.assert_shape(queries, dims['B*D'])  # (B, any_seq_len, D)
    chex.assert_shape(keys, dims['B*D'])     # (B, any_seq_len, D)
    chex.assert_shape(values, dims['B*D'])   # (B, any_seq_len, D)

    # Attention computation (elided)...
    return attention_output
```

### Jittable Assertions

```python
import functools

import chex
import jax
import jax.numpy as jnp

# Enable assertions inside jitted functions
@chex.chexify  # or @chex.with_jittable_assertions
@jax.jit
def safe_division(x, y):
    # These value assertions work inside jit!
    chex.assert_tree_all_finite(x)
    chex.assert_tree_all_finite(y)

    result = x / y
    # A finiteness check on the result also catches division by zero
    chex.assert_tree_all_finite(result)
    return result

# Use with async checking; options are bound via functools.partial,
# since chexify takes the function as its first positional argument
@functools.partial(chex.chexify, async_check=True)
@jax.jit
def training_step(params, batch):
    # Assertions are checked asynchronously
    chex.assert_tree_all_finite(params)
    chex.assert_shape(batch['inputs'], (32, 784))

    # Training computation (compute_loss is assumed to be defined elsewhere)
    loss, grads = jax.value_and_grad(compute_loss)(params, batch)

    chex.assert_tree_all_finite((grads, loss))

    return grads, loss

# Block until all assertions complete
for epoch in range(num_epochs):
    for batch in dataloader:
        grads, loss = training_step(params, batch)
        params = update_params(params, grads)

    # Ensure all assertions from the epoch have been checked
    chex.block_until_chexify_assertions_complete()
    print(f"Epoch {epoch} completed successfully")

# Configure error categories
@functools.partial(chex.chexify, errors=chex.ChexifyChecks.all)  # Check everything
@jax.jit
def comprehensive_checks(data):
    # Enables NaN, indexing, division, and user checks
    return jnp.mean(data)

@functools.partial(chex.chexify,
                   errors=chex.ChexifyChecks.user | chex.ChexifyChecks.nan)
@jax.jit
def custom_checks(data):
    # Only user assertions and NaN checks
    return jnp.sum(data)
```

### Deprecation Management

```python
import functools

import chex

# Mark a function as deprecated; arguments are bound via functools.partial,
# since these utilities take the function as their first positional argument
@functools.partial(chex.warn_deprecated_function,
                   replacement='new_function_name')
def old_function(x):
    """This function is deprecated."""
    return x + 1

# Create a deprecated alias
def current_function(x, y):
    return x * y

# The alias warns users, then delegates to current_function
old_function_name = chex.create_deprecated_function_alias(
    current_function,
    'current_function',
    'old_function_name'
)

# Transition to keyword-only arguments
@functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
def transitioning_function(required_arg, optional_arg=None, another_arg=None):
    """Function transitioning to keyword-only arguments."""
    return required_arg + (optional_arg or 0) + (another_arg or 0)

# Usage that will warn:
# transitioning_function(1, 2, 3)  # Warning: only first arg should be positional

# Preferred usage:
# transitioning_function(1, optional_arg=2, another_arg=3)  # No warning

# Force keyword-only (the signature still accepts positional arguments,
# so the decorator can warn instead of raising a TypeError)
@chex.warn_keyword_args_only_in_future
def keyword_only_function(arg1, arg2):
    """Function that should only accept keyword arguments."""
    return arg1 + arg2

# This will warn:
# keyword_only_function(1, 2)  # Warning about positional args

# This is correct:
# keyword_only_function(arg1=1, arg2=2)  # No warning
```

### Advanced Integration Patterns

```python
import functools

import chex
import jax
import jax.numpy as jnp

class AdvancedTrainer:
    """Training class with advanced Chex features.

    Schematic: model, optimizer, state, and log_metrics are assumed to be
    set up elsewhere.
    """

    def __init__(self, config):
        self.config = config

        # Set up dimensions
        self.dims = chex.Dimensions(
            B=config.batch_size,
            T=config.sequence_length,
            D=config.model_dim,
            C=config.num_classes,
        )

        # Configure backend restrictions
        self.allowed_backends = config.allowed_backends

    def create_training_step(self):
        """Create a jittable training step with assertions."""

        def training_step(state, batch):
            # Validate inputs
            chex.assert_tree_all_finite(state.params)
            chex.assert_shape(batch['inputs'], self.dims['BTD'])
            chex.assert_shape(batch['labels'], self.dims['BC'])

            # Forward pass
            def loss_fn(params):
                logits = self.model.apply(params, batch['inputs'])
                chex.assert_shape(logits, self.dims['BC'])
                # Softmax cross-entropy written with jax.nn primitives
                log_probs = jax.nn.log_softmax(logits)
                return -jnp.mean(jnp.sum(batch['labels'] * log_probs, axis=-1))

            loss, grads = jax.value_and_grad(loss_fn)(state.params)

            # Validate outputs
            chex.assert_tree_all_finite((grads, loss))

            # Update state
            new_state = self.optimizer.update(grads, state)
            return new_state, {'loss': loss}

        # chexify wraps the jitted step so the value assertions above
        # can run inside jit
        return chex.chexify(jax.jit(training_step), async_check=True)

    def train(self, train_data):
        """Training loop with backend restriction."""

        # Restrict to allowed backends during training
        with chex.restrict_backends(allowed=self.allowed_backends):
            training_step = self.create_training_step()

            for epoch in range(self.config.num_epochs):
                for step, batch in enumerate(train_data):
                    # Validate batch dimensions dynamically
                    actual_batch_size = batch['inputs'].shape[0]
                    if (actual_batch_size,) != self.dims['B']:
                        # Update dimensions for the final batch
                        self.dims['B'] = (actual_batch_size,)

                    state, metrics = training_step(self.state, batch)
                    self.state = state

                    if step % 100 == 0:
                        # Ensure all async assertions have completed
                        chex.block_until_chexify_assertions_complete()
                        self.log_metrics(metrics, epoch, step)

# Integration with existing codebases
def modernize_legacy_function():
    """Example of gradually modernizing legacy code."""

    # Original function (deprecated)
    @functools.partial(chex.warn_deprecated_function,
                       replacement='process_data_v2')
    def process_data_v1(data, normalize, scale):
        return data * scale if normalize else data

    # New function with a better API
    @functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
    def process_data_v2(data, *, normalize=False, scale=1.0):
        # Add shape validation
        chex.assert_rank(data, 2)
        chex.assert_scalar_positive(scale)

        if normalize:
            data = data / jnp.linalg.norm(data, axis=1, keepdims=True)

        return data * scale

    # Future version (keyword-only)
    def process_data_v3(*, data, normalize=False, scale=1.0):
        # Enhanced with jittable assertions; normalize and scale are kept
        # static so Python control flow and the scalar check work under jit
        @chex.chexify
        @functools.partial(jax.jit, static_argnums=(1, 2))
        def _process(data, normalize, scale):
            chex.assert_rank(data, 2)
            chex.assert_scalar_positive(scale)
            chex.assert_tree_all_finite(data)

            if normalize:
                norms = jnp.linalg.norm(data, axis=1, keepdims=True)
                chex.assert_tree_all_finite(norms)
                data = data / norms

            result = data * scale
            chex.assert_tree_all_finite(result)
            return result

        return _process(data, normalize, scale)
```

## Key Features

### Fine-Grained Control
- Precise backend restrictions for different computation phases
- Flexible dimension management with arithmetic operations
- Configurable assertion checking with multiple error categories

### Production Ready
- Async assertion checking for minimal performance impact
- Deprecation management for smooth API transitions
- Integration with the existing JAX transformation pipeline

### Developer Friendly
- Clear error messages and warnings
- Readable dimension specifications
- Comprehensive debugging support

## Best Practices

### Use Backend Restrictions Strategically
```python
import chex

# Good: Restrict during specific phases
with chex.restrict_backends(allowed=['cpu']):
    # Memory-intensive preprocessing
    pass

# Avoid: Overly broad restrictions
with chex.restrict_backends(forbidden=['gpu']):
    # Entire training loop - might be unnecessarily restrictive
    pass
```

### Design Maintainable Dimension Systems
```python
import chex

# Good: Centralized dimension management
dims = chex.Dimensions(B=32, T=100, D=512)

# Good: Clear dimension naming - keys are single characters in lookup
# strings such as dims['BTD'], so document each letter rather than
# using multi-character names
dims = chex.Dimensions(
    B=32,   # batch size
    T=100,  # sequence length
    D=512,  # embedding dimension
)
```

### Plan Deprecation Carefully
```python
import functools

import chex

# Good: Provide a clear migration path
@functools.partial(chex.warn_deprecated_function,
                   replacement='new_api_function')
def old_function():
    pass

# Good: Gradual transition
@functools.partial(chex.warn_only_n_pos_args_in_future, n=1)
def transitioning_function(required, *, optional=None):
    pass
```