or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

advanced.md, assertions.md, dataclasses.md, debugging.md, index.md, testing.md, types.md

docs/assertions.md

0

# Assertion Functions

1

2

Comprehensive validation utilities for JAX computations. These functions provide essential testing and debugging capabilities for validating tensor properties, shapes, values, and computational correctness in JAX programs.

3

4

## Capabilities

5

6

### Shape and Dimension Assertions

7

8

Functions for validating array shapes, dimensions, and structural properties.

9

10

```python { .api }

11

def assert_shape(array, expected_shape):

12

"""

13

Assert that array has the expected shape.

14

15

Parameters:

16

- array: Array to check

17

- expected_shape: Expected shape tuple, supports None for wildcard dimensions

18

"""

19

20

def assert_rank(array, expected_rank):

21

"""

22

Assert that array has the expected number of dimensions.

23

24

Parameters:

25

- array: Array to check

26

- expected_rank: Expected number of dimensions (int)

27

"""

28

29

def assert_size(array, expected_size):

30

"""

31

Assert that array has the expected total size.

32

33

Parameters:

34

- array: Array to check

35

- expected_size: Expected total number of elements (int)

36

"""

37

38

def assert_equal_shape(inputs, *, dims=None):

39

"""

40

Assert that all arrays have the same shape.

41

42

Parameters:

43

- inputs: Sequence of arrays to compare

44

- dims: Optional int or sequence of ints specifying which dimensions to compare

45

"""

46

47

def assert_equal_rank(inputs):

48

"""

49

Assert that all arrays have the same rank (number of dimensions).

50

51

Parameters:

52

- inputs: Sequence of arrays to compare

53

"""

54

55

def assert_equal_size(inputs):

56

"""

57

Assert that all arrays have the same total size.

58

59

Parameters:

60

- inputs: Sequence of arrays to compare

61

"""

62

63

def assert_equal_shape_prefix(inputs, prefix_len):

64

"""

65

Assert that the leading prefix_len dimensions of all inputs have same shape.

66

67

Parameters:

68

- inputs: Sequence of arrays to compare

69

- prefix_len: Number of leading dimensions to compare

70

"""

71

72

def assert_equal_shape_suffix(inputs, suffix_len):

73

"""

74

Assert that the final suffix_len dimensions of all inputs have same shape.

75

76

Parameters:

77

- inputs: Sequence of arrays to compare

78

- suffix_len: Number of trailing dimensions to compare

79

"""

80

```

81

82

### Axis-Specific Assertions

83

84

Functions for validating specific axis dimensions with comparison operators.

85

86

```python { .api }

87

def assert_axis_dimension(tensor, axis, expected):

88

"""

89

Assert that a specific axis has the expected dimension size.

90

91

Parameters:

92

- tensor: Array to check

93

- axis: Axis index to check

94

- expected: Expected dimension size for the axis

95

"""

96

97

def assert_axis_dimension_comparator(tensor, axis, pass_fn, error_string):

98

"""

99

Assert that pass_fn(tensor.shape[axis]) passes.

100

101

Used to implement ==, >, >=, <, <= checks.

102

103

Parameters:

104

- tensor: JAX array to check

105

- axis: Axis index to check

106

- pass_fn: Function that takes dimension size and returns bool

107

- error_string: Error message to display if assertion fails

108

"""

109

110

def assert_axis_dimension_gt(tensor, axis, val):

111

"""

112

Assert that axis dimension is greater than the given value.

113

114

Parameters:

115

- tensor: Array to check

116

- axis: Axis index to check

117

- val: Minimum size (exclusive)

118

"""

119

120

def assert_axis_dimension_gteq(tensor, axis, val):

121

"""

122

Assert that axis dimension is greater than or equal to the given value.

123

124

Parameters:

125

- tensor: Array to check

126

- axis: Axis index to check

127

- val: Minimum size (inclusive)

128

"""

129

130

def assert_axis_dimension_lt(tensor, axis, val):

131

"""

132

Assert that axis dimension is less than the given value.

133

134

Parameters:

135

- tensor: Array to check

136

- axis: Axis index to check

137

- val: Maximum size (exclusive)

138

"""

139

140

def assert_axis_dimension_lteq(tensor, axis, val):

141

"""

142

Assert that axis dimension is less than or equal to the given value.

143

144

Parameters:

145

- tensor: Array to check

146

- axis: Axis index to check

147

- val: Maximum size (inclusive)

148

"""

149

```

150

151

### Value and Content Assertions

152

153

Functions for validating array values and content properties.

154

155

```python { .api }

156

def assert_equal(first, second):

157

"""

158

Assert that two objects are equal as determined by the == operator.

159

160

Arrays with more than one element cannot be compared.

161

Use assert_trees_all_close to compare arrays.

162

163

Parameters:

164

- first: First object to compare

165

- second: Second object to compare

166

"""

167

168

def assert_scalar(value):
    """
    Assert that value is a Python scalar (int or float).

    Arrays are rejected, including rank-0 arrays such as jnp.array(1.0).

    Parameters:
    - value: Value to check
    """

175

176

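def assert_scalar_in(value, min_, max_, included=True):
    """
    Assert that scalar value lies within the segment [min_, max_].

    The value must itself be a Python scalar (int or float).

    Parameters:
    - value: Scalar value to check
    - min_: Left border of the segment
    - max_: Right border of the segment
    - included: If True (default), the borders are allowed values (closed
      interval); if False, they are excluded (open interval)
    """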

184

185

def assert_scalar_positive(value):

186

"""

187

Assert that scalar value is positive (> 0).

188

189

Parameters:

190

- value: Scalar value to check

191

"""

192

193

def assert_scalar_non_negative(value):

194

"""

195

Assert that scalar value is non-negative (>= 0).

196

197

Parameters:

198

- value: Scalar value to check

199

"""

200

201

def assert_scalar_negative(value):

202

"""

203

Assert that scalar value is negative (< 0).

204

205

Parameters:

206

- value: Scalar value to check

207

"""

208

209

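def assert_type(inputs, expected_types):
    """
    Assert that the type (dtype, for arrays) of every input matches the expected type.

    Works for Python scalars and arrays, e.g. assert_type(7, int),
    assert_type(np.array(7.1), float), assert_type([7, 7.1], [int, float]).

    Parameters:
    - inputs: A scalar, an array, or a sequence of scalars/arrays to check
    - expected_types: An expected type/dtype (e.g. int, float, jnp.bfloat16),
      applied to all inputs, or a sequence with one expected type per input
    """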

217

```

218

219

### Tree Structure Assertions

220

221

Functions for validating JAX pytree structures and their properties.

222

223

```python { .api }

224

def assert_tree_shape(tree, expected_shape):

225

"""

226

Assert that all arrays in the tree have the expected shape.

227

228

Parameters:

229

- tree: JAX pytree containing arrays

230

- expected_shape: Expected shape for all arrays in tree

231

"""

232

233

def assert_tree_shape_prefix(tree, prefix_shape):

234

"""

235

Assert that all arrays in tree have shapes starting with given prefix.

236

237

Parameters:

238

- tree: JAX pytree containing arrays

239

- prefix_shape: Shape prefix that all arrays should have

240

"""

241

242

def assert_tree_shape_suffix(tree, suffix_shape):

243

"""

244

Assert that all arrays in tree have shapes ending with given suffix.

245

246

Parameters:

247

- tree: JAX pytree containing arrays

248

- suffix_shape: Shape suffix that all arrays should have

249

"""

250

251

def assert_tree_all_finite(tree):

252

"""

253

Assert that all values in the tree are finite (not NaN or infinite).

254

255

Parameters:

256

- tree: JAX pytree containing arrays

257

"""

258

259

def assert_tree_has_only_ndarrays(tree):

260

"""

261

Assert that tree contains only numpy/JAX arrays.

262

263

Parameters:

264

- tree: JAX pytree to check

265

"""

266

267

def assert_tree_no_nones(tree):

268

"""

269

Assert that tree contains no None values.

270

271

Parameters:

272

- tree: JAX pytree to check

273

"""

274

275

def assert_tree_is_on_device(tree, device):

276

"""

277

Assert that all arrays in tree are on the specified device.

278

279

Parameters:

280

- tree: JAX pytree containing arrays

281

- device: Expected device

282

"""

283

284

def assert_tree_is_on_host(tree):

285

"""

286

Assert that all arrays in tree are on host (CPU).

287

288

Parameters:

289

- tree: JAX pytree containing arrays

290

"""

291

292

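def assert_tree_is_sharded(tree, *, devices, ignore_nones=False):
    """
    Assert that all arrays in tree are sharded across the given devices.

    Parameters:
    - tree: JAX pytree containing arrays
    - devices: Sequence of devices the arrays are expected to be sharded across
    - ignore_nones: Whether None leaves are ignored
    """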

299

```

300

301

### Multi-Tree Comparisons

302

303

Functions for comparing multiple JAX pytrees.

304

305

```python { .api }

306

def assert_trees_all_equal(*trees):

307

"""

308

Assert that all trees are exactly equal in structure and values.

309

310

Parameters:

311

- *trees: Variable number of JAX pytrees to compare

312

"""

313

314

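def assert_trees_all_equal_comparator(equality_comparator, error_msg_fn, *trees):
    """
    Assert that all trees are equal, comparing leaves with a custom function.

    Parameters:
    - equality_comparator: Function taking two leaves and returning True if
      they are considered equal; expected to be transitive
    - error_msg_fn: Function taking two leaves that the comparator found
      unequal and returning an error message
    - *trees: Two or more JAX pytrees to compare
    """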

322

323

def assert_trees_all_equal_dtypes(*trees):

324

"""

325

Assert that all trees have matching data types.

326

327

Parameters:

328

- *trees: Variable number of JAX pytrees to compare

329

"""

330

331

def assert_trees_all_equal_shapes(*trees):

332

"""

333

Assert that all trees have matching shapes.

334

335

Parameters:

336

- *trees: Variable number of JAX pytrees to compare

337

"""

338

339

def assert_trees_all_equal_shapes_and_dtypes(*trees):

340

"""

341

Assert that all trees have matching shapes and data types.

342

343

Parameters:

344

- *trees: Variable number of JAX pytrees to compare

345

"""

346

347

def assert_trees_all_equal_sizes(*trees):

348

"""

349

Assert that all trees have matching sizes.

350

351

Parameters:

352

- *trees: Variable number of JAX pytrees to compare

353

"""

354

355

def assert_trees_all_equal_structs(*trees):

356

"""

357

Assert that all trees have matching structures (ignoring values).

358

359

Parameters:

360

- *trees: Variable number of JAX pytrees to compare

361

"""

362

363

def assert_trees_all_close(*trees, rtol=1e-06, atol=0.0):
    """
    Assert that all trees have leaves with approximately equal values.

    Leaves are compared elementwise up to atol + rtol * abs(desired).

    Parameters:
    - *trees: Two or more JAX pytrees to compare
    - rtol: Relative tolerance (default 1e-06)
    - atol: Absolute tolerance (default 0.0)
    """

372

373

def assert_trees_all_close_ulp(*trees, maxulp=1):
    """
    Assert that tree leaves differ by at most maxulp Units in the Last Place.

    Parameters:
    - *trees: Two or more JAX pytrees to compare
    - maxulp: Maximum units in the last place difference allowed (default 1)
    """

381

```

382

383

### Device and Hardware Assertions

384

385

Functions for validating device availability and placement.

386

387

```python { .api }

388

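def assert_devices_available(n, devtype, backend=None, not_less_than=False):
    """
    Assert that n devices of the given type are available.

    Parameters:
    - n: Required number of devices
    - devtype: Device type, one of 'cpu', 'gpu', 'tpu'
    - backend: Backend to query (JAX default backend if None)
    - not_less_than: If True, require at least n devices instead of exactly n
    """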

395

396

def assert_gpu_available():

397

"""

398

Assert that at least one GPU device is available.

399

"""

400

401

def assert_tpu_available():

402

"""

403

Assert that at least one TPU device is available.

404

"""

405

```

406

407

### Utility Assertions

408

409

Helper functions for common validation patterns.

410

411

```python { .api }

412

def assert_exactly_one_is_none(first, second):
    """
    Assert that exactly one of the two given values is None.

    Parameters:
    - first, second: Values to check
    """

419

420

def assert_not_both_none(value1, value2):

421

"""

422

Assert that at least one of the two values is not None.

423

424

Parameters:

425

- value1, value2: Values to check

426

"""

427

428

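def assert_is_broadcastable(shape1, shape2):
    """
    Assert that an array of shape1 can be broadcast to shape2.

    The check is one-directional: every trailing dimension of shape1 must be
    1 or equal to the matching dimension of shape2, and shape1 may not have
    more dimensions than shape2. For example, (1, 4) -> (3, 4) passes, but
    (3, 4) -> (1, 4) fails.

    Parameters:
    - shape1: Source shape tuple
    - shape2: Target shape tuple
    """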

435

436

def assert_is_divisible(dividend, divisor):

437

"""

438

Assert that dividend is evenly divisible by divisor.

439

440

Parameters:

441

- dividend: Number to divide

442

- divisor: Number to divide by

443

"""

444

445

def assert_numerical_grads(fn, args, order=1, **kwargs):

446

"""

447

Assert that analytical gradients match numerical gradients.

448

449

Parameters:

450

- fn: Function to test gradients for

451

- args: Arguments to pass to function

452

- order: Order of derivative to test

453

- **kwargs: Additional arguments for numerical gradient computation

454

"""

455

```

456

457

### Assertion Control

458

459

Functions for controlling assertion behavior globally.

460

461

```python { .api }

462

def enable_asserts():

463

"""

464

Enable all Chex assertions (default state).

465

"""

466

467

def disable_asserts():

468

"""

469

Disable all Chex assertions for performance.

470

"""

471

472

def if_args_not_none(fn, *args, **kwargs):

473

"""

474

Execute assertion function only if all positional arguments are not None.

475

476

Parameters:

477

- fn: Assertion function to conditionally execute

478

- *args: Arguments to pass to fn

479

- **kwargs: Keyword arguments to pass to fn

480

"""

481

482

def clear_trace_counter():

483

"""

484

Clear the trace counter used by assert_max_traces.

485

"""

486

487

def assert_max_traces(fn, n):

488

"""

489

Decorator/wrapper to assert function is traced at most n times.

490

491

Parameters:

492

- fn: Function to wrap or n (number of max traces) if used as decorator

493

- n: Maximum number of traces allowed (if fn is a function)

494

495

Returns:

496

- Wrapped function or decorator

497

"""

498

```

499

500

## Usage Examples

501

502

### Basic Shape Validation

503

504

```python

505

import chex

506

import jax.numpy as jnp

507

508

# Create test arrays

509

x = jnp.array([[1, 2, 3], [4, 5, 6]]) # Shape: (2, 3)

510

y = jnp.zeros((2, 3))

511

512

# Validate shapes

513

chex.assert_shape(x, (2, 3)) # Passes

514

chex.assert_equal_shape([x, y]) # Passes - note list of arrays

515

chex.assert_rank(x, 2) # Passes

516

517

# Wildcard dimensions

518

z = jnp.ones((2, 5))

519

chex.assert_shape(z, (2, None)) # Passes - None matches any size

520

```

521

522

### Tree Validation

523

524

```python

525

import jax

# Create a pytree
tree = {
    'weights': jnp.array([[1, 2], [3, 4]]),
    'bias': jnp.array([0.1, 0.2]),
    'nested': {'param': jnp.array([1.0])}
}

# Validate tree properties
chex.assert_tree_all_finite(tree)
chex.assert_tree_has_only_ndarrays(tree)

# Compare trees (jax.tree_map is deprecated/removed; use jax.tree_util.tree_map)
tree2 = jax.tree_util.tree_map(lambda x: x + 0.01, tree)
chex.assert_trees_all_close(tree, tree2, atol=0.02)

539

```

540

541

### Conditional Assertions

542

543

```python

544

def process_data(data, weights=None):
    chex.assert_shape(data, (None, 10))  # Any batch size, 10 features

    # Only check weights if provided
    chex.if_args_not_none(chex.assert_shape, weights, (10, 5))

    return data @ weights if weights is not None else data

551

```