
# Test Variants and Testing Infrastructure

A testing framework for running the same test code under multiple JAX execution variants: with and without `jax.jit`, with arguments in device or host memory, and under `jax.pmap`. A single test can then validate JAX code behavior across all of them.

## Capabilities

### Test Base Classes

Base test case class providing the variant testing infrastructure.

```python { .api }
class TestCase(parameterized.TestCase):
    """
    Base class for Chex tests that use variants.

    Provides infrastructure for running tests across multiple JAX execution modes.
    Subclasses absl.testing.parameterized.TestCase, which unrolls the per-variant
    test generators produced by the variant decorators at class-creation time.
    """

    def variant(self, *args, **kwargs):
        """
        Access the current test variant function.

        This method is dynamically replaced by the @variants decorator
        with the appropriate transformation (jax.jit, identity, etc.).

        Raises:
        - RuntimeError: If called in a test that has no variants decorator
        """
```
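
The placeholder-and-replace contract above can be modeled in plain Python. This is a hypothetical toy, not Chex's implementation: `ToyTestCase`, `with_identity_variant`, and `Demo` are illustrative names, and the toy decorator simply binds an identity transformation to `self.variant` for the duration of one test call.

```python
class ToyTestCase:
    """Hypothetical stand-in for chex.TestCase (illustration only)."""

    def variant(self, *args, **kwargs):
        # Placeholder: only a variants-style decorator should make this usable.
        raise RuntimeError("self.variant is not set: is the test decorated?")


def with_identity_variant(test_fn):
    """Toy decorator: expose an identity transformation as self.variant."""

    def wrapper(self):
        # An instance attribute shadows the class-level placeholder method.
        self.variant = lambda fn: fn
        try:
            return test_fn(self)
        finally:
            # Remove the instance attribute so the placeholder is visible again.
            del self.variant

    return wrapper


class Demo(ToyTestCase):

    @with_identity_variant
    def decorated(self):
        return self.variant(lambda x: x + 1)(41)

    def undecorated(self):
        # No decorator, so this hits the placeholder and raises RuntimeError.
        return self.variant(lambda x: x + 1)(41)


print(Demo().decorated())  # 42
```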

### Variant Decorators

Decorators for running tests across multiple execution modes.

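```python { .api }
def variants(with_jit=False,
             without_jit=False,
             with_device=False,
             without_device=False,
             with_pmap=False):
    """
    Decorator to run a test across the selected variants.

    Parameters:
    - with_jit: Include a variant that applies jax.jit to the function
    - without_jit: Include a variant that uses the function as is (identity)
    - with_device: Include a variant that places arguments in device memory first
    - without_device: Include a variant that places arguments in host memory first
    - with_pmap: Include a variant that applies jax.pmap to the function

    Returns:
    - Generator yielding one test per selected variant

    Example:
        @variants(with_jit=True, without_jit=True)
        def test_function(self):
            fn = self.variant(my_function)
            # Test implementation
    """


def all_variants(with_jit=True,
                 without_jit=True,
                 with_device=True,
                 without_device=True,
                 with_pmap=True):
    """
    Decorator to run a test across all available variants.

    Can be applied bare (@all_variants) or called with individual flags set
    to False to exclude variants, e.g. @all_variants(with_pmap=False).

    Returns:
    - Generator yielding one test per selected variant
    """
```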

### Variant Types

Enumeration of available test variant types.

```python { .api }
class ChexVariantType(Enum):
    """
    Enumeration of available Chex test variants.

    Use self.variant.type to get the type of the current test variant.
    """

    WITH_JIT = 1        # Function wrapped with jax.jit
    WITHOUT_JIT = 2     # Function used as is (identity transformation)
    WITH_DEVICE = 3     # Arguments placed in device memory before the call
    WITHOUT_DEVICE = 4  # Arguments placed in host memory (RAM) before the call
    WITH_PMAP = 5       # Function wrapped with jax.pmap
```
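
Each decorator flag has an enum member with the same name in upper case (`with_jit` and `WITH_JIT`, and so on), which keeps dispatch on `self.variant.type` straightforward. The sketch below relies only on that naming correspondence; `ToyVariantType` mirrors the members listed above instead of importing Chex, and `selected_types` and `uses_compilation` are hypothetical helpers.

```python
import enum


class ToyVariantType(enum.Enum):
    """Mirror of the members listed above (not imported from chex)."""
    WITH_JIT = 1
    WITHOUT_JIT = 2
    WITH_DEVICE = 3
    WITHOUT_DEVICE = 4
    WITH_PMAP = 5


def selected_types(**flags):
    """Members a call such as variants(with_jit=True, ...) would select."""
    return [ToyVariantType[name.upper()] for name, enabled in flags.items() if enabled]


def uses_compilation(variant_type):
    """Example dispatch on the variant type: jit and pmap both compile with XLA."""
    return variant_type in (ToyVariantType.WITH_JIT, ToyVariantType.WITH_PMAP)


types = selected_types(with_jit=True, without_jit=True, with_pmap=False)
print(types)  # [<ToyVariantType.WITH_JIT: 1>, <ToyVariantType.WITHOUT_JIT: 2>]
print([uses_compilation(t) for t in types])  # [True, False]
```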

### Parameter Generation

Utilities for generating test parameter combinations.

```python { .api }
def params_product(*params_lists, named=False):
    """
    Generate the cartesian product of parameter lists for parameterized tests.

    Parameters:
    - *params_lists: Sequences of parameter tuples; each combination is formed
      by concatenating one tuple from every sequence
    - named: If True, the first element of each tuple is a name fragment; the
      fragments are joined with "_" and placed first in each combination, for
      use with parameterized.named_parameters

    Returns:
    - List of parameter tuples

    Example:
        # Generate all combinations of batch sizes and learning rates
        params = params_product([(32,), (64,)], [(0.01,), (0.001,)])
        # [(32, 0.01), (32, 0.001), (64, 0.01), (64, 0.001)]
    """
```
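
`params_product` combines tuples rather than bare values: each combination concatenates one tuple from every list, and with `named=True` the first element of each tuple is treated as a name fragment. The pure-Python sketch below models that rule with `itertools.product` so the output is easy to inspect. `toy_params_product` is a hypothetical name, and the sketch models the behavior under that assumption rather than reproducing Chex's source.

```python
import itertools


def toy_params_product(*params_lists, named=False):
    """Cartesian product of lists of tuples (a model of chex.params_product)."""
    combos = []
    for combination in itertools.product(*params_lists):
        if named:
            # First element of each tuple is a name fragment; the rest are args.
            name = "_".join(t[0] for t in combination)
            args = tuple(arg for t in combination for arg in t[1:])
            combos.append((name, *args))
        else:
            # Concatenate the chosen tuples into one flat argument tuple.
            combos.append(tuple(arg for t in combination for arg in t))
    return combos


print(toy_params_product([(32,), (64,)], [(0.01,), (0.001,)]))
# [(32, 0.01), (32, 0.001), (64, 0.01), (64, 0.001)]

print(toy_params_product([("bs32", 32), ("bs64", 64)],
                         [("lr_hi", 0.01), ("lr_lo", 0.001)], named=True))
# [('bs32_lr_hi', 32, 0.01), ('bs32_lr_lo', 32, 0.001),
#  ('bs64_lr_hi', 64, 0.01), ('bs64_lr_lo', 64, 0.001)]
```

Passing bare values such as `[32, 64]` does not fit this rule, because there is no tuple to concatenate; wrap each value in a one-element tuple instead.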

## Usage Examples

### Basic Variant Testing

```python
import chex
import jax.numpy as jnp
from absl.testing import absltest


class MyTest(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    def test_addition_function(self):
        def add_one(x):
            return x + 1

        # Get the variant-appropriate version of the function
        fn = self.variant(add_one)

        # Test the function
        result = fn(jnp.array([1, 2, 3]))
        expected = jnp.array([2, 3, 4])

        # chex.assert_equal relies on ==, which is ambiguous for multi-element
        # arrays, so compare the arrays element-wise instead
        chex.assert_trees_all_equal(result, expected)

        # Access the variant type if needed
        if self.variant.type == chex.ChexVariantType.WITH_JIT:
            # This test is running with jit
            pass


if __name__ == "__main__":
    absltest.main()
```

### Testing All Variants

```python
import chex
import jax.numpy as jnp


class ComprehensiveTest(chex.TestCase):

    @chex.all_variants
    def test_matrix_multiply(self):
        def matmul(a, b):
            return jnp.dot(a, b)

        fn = self.variant(matmul)

        a = jnp.array([[1, 2], [3, 4]])
        b = jnp.array([[5, 6], [7, 8]])

        result = fn(a, b)
        expected = jnp.array([[19, 22], [43, 50]])

        # Element-wise comparison of the two arrays
        chex.assert_trees_all_equal(result, expected)
```

### Parameterized Variant Testing

```python
import chex
import jax.numpy as jnp
from absl.testing import parameterized


class ParameterizedVariantTest(chex.TestCase):

    # Apply @chex.variants above the parameterized decorator
    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.parameters(
        {'batch_size': 32, 'input_dim': 784},
        {'batch_size': 64, 'input_dim': 1024},
    )
    def test_neural_network_layer(self, batch_size, input_dim):
        def linear_layer(x, weights, bias):
            return jnp.dot(x, weights) + bias

        fn = self.variant(linear_layer)

        # Create test data
        x = jnp.ones((batch_size, input_dim))
        weights = jnp.ones((input_dim, 10))
        bias = jnp.zeros(10)

        result = fn(x, weights, bias)

        # Verify output shape
        chex.assert_shape(result, (batch_size, 10))

        # Verify computation: each output sums input_dim products of 1 * 1
        expected = jnp.full((batch_size, 10), float(input_dim))
        chex.assert_trees_all_close(result, expected)
```

### Using Parameter Products

```python
import chex
import jax.numpy as jnp
from absl.testing import parameterized


class ProductTest(chex.TestCase):

    # Generate all combinations of optimizers and learning rates.
    # With named=True the first element of each tuple is a name fragment,
    # so each case looks like ('sgd_lr_high', 'sgd', 0.1).
    @chex.variants(with_jit=True, without_jit=True)
    @parameterized.named_parameters(
        chex.params_product(
            (('sgd', 'sgd'), ('adam', 'adam'), ('rmsprop', 'rmsprop')),
            (('lr_high', 0.1), ('lr_mid', 0.01), ('lr_low', 0.001)),
            named=True,
        )
    )
    def test_optimizer_update(self, optimizer_name, learning_rate):
        def update_step(params, grads, lr):
            return params - lr * grads

        fn = self.variant(update_step)

        params = jnp.array([1.0, 2.0, 3.0])
        grads = jnp.array([0.1, 0.2, 0.3])

        updated_params = fn(params, grads, learning_rate)
        expected = params - learning_rate * grads

        chex.assert_trees_all_close(updated_params, expected)
```

### Testing with Device Variants

```python
import chex
import jax.numpy as jnp


class DeviceTest(chex.TestCase):

    @chex.variants(with_device=True, without_device=True)
    def test_device_placement(self):
        def compute_sum(x):
            return jnp.sum(x)

        fn = self.variant(compute_sum)

        x = jnp.array([1, 2, 3, 4, 5])
        result = fn(x)

        # result is a 0-d array, so == yields a single boolean
        chex.assert_equal(result, 15)

        # Device placement can be checked here if needed
        if self.variant.type == chex.ChexVariantType.WITH_DEVICE:
            # Arguments were placed in device memory before the call
            pass
```

### Testing with Pmap Variants

```python
import chex
import jax
import jax.numpy as jnp


class PmapTest(chex.TestCase):

    @chex.variants(with_pmap=True)
    def test_parallel_computation(self):
        def parallel_square(x):
            return x ** 2

        fn = self.variant(parallel_square)

        # By default the pmap variant broadcasts each argument to every device
        # and returns the first device's output, so the result keeps x's shape
        n_devices = jax.local_device_count()
        x = jnp.arange(n_devices * 4).reshape(n_devices, 4)

        result = fn(x)
        expected = x ** 2

        chex.assert_trees_all_equal(result, expected)
```
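
The pmap variant's default argument handling is easy to misread. Assuming Chex's defaults of broadcasting every argument to all devices and keeping only the first device's output (the `broadcast_args_to_devices` and `reduce_fn` options of the `with_pmap` variant), the data flow can be modeled sequentially in plain Python. `simulate_pmap_variant` and `square_all` are hypothetical helpers; real `jax.pmap` runs the per-device calls in parallel on separate devices.

```python
def simulate_pmap_variant(fn, *args, n_devices=2):
    """Sequential toy model of a broadcast-then-first-output pmap variant."""
    # Broadcast: every simulated device receives the full, unsplit arguments.
    per_device_args = [args for _ in range(n_devices)]
    # Map: call fn once per device (pmap would run these in parallel).
    per_device_outputs = [fn(*device_args) for device_args in per_device_args]
    # Reduce: keep only the first device's output.
    return per_device_outputs[0]


def square_all(xs):
    return [v ** 2 for v in xs]


x = [0, 1, 2, 3]
print(simulate_pmap_variant(square_all, x, n_devices=4))  # [0, 1, 4, 9]
```

Under this model the output has the same shape as the input, which is why the pmap test can compare `fn(x)` directly with `x ** 2`.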

### Advanced Variant Usage

```python
import chex
import jax
import jax.numpy as jnp


class AdvancedVariantTest(chex.TestCase):

    def setUp(self):
        super().setUp()
        # Setup that runs before each variant
        self.tolerance = 1e-5

    @chex.all_variants
    def test_gradient_computation(self):
        def loss_fn(params, data):
            return jnp.sum((params['w'] @ data - params['b']) ** 2)

        # Apply the variant to the gradient function, i.e. the code under test
        grad_fn = self.variant(jax.grad(loss_fn))

        # Create test data
        params = {'w': jnp.array([[1.0, 2.0]]), 'b': jnp.array([0.5])}
        data = jnp.array([[1.0], [2.0]])

        # Compute gradients
        grads = grad_fn(params, data)

        # Verify gradient structure matches params
        chex.assert_trees_all_equal_structs(grads, params)

        # Verify gradients are finite
        chex.assert_tree_all_finite(grads)

    # The variants decorator must wrap the test method itself: it returns a
    # generator that the test class unrolls, so a nested helper never runs.
    @chex.variants(with_jit=True, without_jit=True)
    def test_variant_type_specific_behavior(self):
        """Test that demonstrates variant-specific testing logic."""
        def expensive_computation(x):
            return jnp.sum(jnp.sin(x) * jnp.cos(x))

        fn = self.variant(expensive_computation)
        x = jnp.linspace(0, 2 * jnp.pi, 1000)
        result = fn(x)

        # The reduction yields a 0-d array, not a Python scalar, so check its rank
        chex.assert_rank(result, 0)

        # Different expectations based on variant type
        if self.variant.type == chex.ChexVariantType.WITH_JIT:
            # Jitted code may reorder floating-point operations, so compare with
            # the un-jitted computation using a tolerance, not exact equality
            chex.assert_trees_all_close(
                result, expensive_computation(x), atol=self.tolerance)
```
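
Floating-point addition is not associative, so a compiled variant that fuses or reorders a reduction may differ from the un-jitted result in the last few bits. The standard-library sketch below sums the same terms as `expensive_computation` in two orders and compares both against a correctly rounded reference (`math.fsum`). It uses Python floats rather than JAX arrays, and the helper `linspace` is a hypothetical stand-in for `jnp.linspace`, so it only illustrates why variant comparisons use a tolerance.

```python
import math


def linspace(start, stop, num):
    """Evenly spaced points with both endpoints included (like jnp.linspace)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


xs = linspace(0.0, 2 * math.pi, 1000)
terms = [math.sin(x) * math.cos(x) for x in xs]

forward = sum(terms)             # left-to-right accumulation
backward = sum(reversed(terms))  # same terms, opposite order
reference = math.fsum(terms)     # correctly rounded sum of the same terms

# Compare with a tolerance, never with ==
tolerance = 1e-9
print(abs(forward - reference) <= tolerance, abs(backward - reference) <= tolerance)
# True True
```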

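## Key Features

### Comprehensive Coverage

- Runs the same test logic under each selected execution mode
- Catches bugs that appear only under a specific transformation, such as Python control flow on traced values, which fails under `jax.jit` but not without it
- Checks that jitted and non-jitted code behave consistently

### Easy Integration

- `chex.TestCase` subclasses `absl.testing.parameterized.TestCase`, so it can replace that base class directly
- Works with `@parameterized.parameters` and `@parameterized.named_parameters`
- Existing tests need only a decorator and a `self.variant(...)` wrapper around the function under test

### Flexible Configuration

- Select specific variants with `@chex.variants(...)` or run all of them with `@chex.all_variants`
- Combine with parameterized testing
- Support for device-placement (`with_device`, `without_device`) and multi-device (`with_pmap`) testing

### Debugging Support

- `self.variant.type` identifies the variant that is currently running
- Each variant runs as a separate test, so a failure is reported for the specific variant that failed
- Integrates with the Chex assertion framework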

## Best Practices

### Use Meaningful Test Names

```python
@chex.variants(with_jit=True, without_jit=True)
def test_neural_network_forward_pass_consistency(self):
    # Clear test purpose
    pass
```

### Test Critical Paths

```python
# Focus variant testing on functions that will be jitted in practice
@chex.all_variants
def test_training_step(self):
    # This will be jitted in real usage
    pass
```

### Combine with Assertions

```python
@chex.all_variants
def test_with_comprehensive_checks(self):
    def my_function(x):
        # Placeholder for the function under test
        return jnp.tanh(x)

    fn = self.variant(my_function)
    input_data = jnp.ones((4, 3))
    result = fn(input_data)

    # Use Chex assertions for thorough validation
    chex.assert_shape(result, input_data.shape)
    chex.assert_tree_all_finite(result)
```