or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

Files

docs

core-transformations.md, device-memory.md, experimental.md, index.md, low-level-ops.md, neural-networks.md, numpy-compatibility.md, random-numbers.md, scipy-compatibility.md, tree-operations.md

docs/experimental.md

0

# Experimental Features

1

2

JAX experimental features provide access to cutting-edge capabilities, performance optimizations, and research functionality through `jax.experimental`. These features may change or be moved to the main JAX API in future versions.

3

4

**Warning**: Experimental APIs may change without notice between JAX versions. Use with caution in production code.

5

6

## Core Imports

7

8

```python

9

import jax.experimental as jex

10

from jax.experimental import io_callback, enable_x64

11

```

12

13

## Capabilities

14

15

### Precision Control

16

17

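Enable or disable 64-bit (double) floating-point precision for JAX computations. In `jax.experimental` these are context managers: the setting applies only inside a `with` block and is restored when the block exits.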

18

19

```python { .api }

20

import contextlib
from collections.abc import Iterator


@contextlib.contextmanager
def enable_x64(new_val: bool = True) -> Iterator[None]:
    """
    Context manager that temporarily sets 64-bit floating point precision.

    Args:
      new_val: Value the ``jax_enable_x64`` flag takes inside the ``with``
        block (default: True).

    Note:
      The change is scoped, not global: the previous setting is restored
      when the ``with`` block exits. Calling ``enable_x64()`` without
      ``with`` creates the context manager but never enters it, so it has
      no effect.
    """
    # The body of the caller's ``with`` block runs at this point.
    yield


@contextlib.contextmanager
def disable_x64() -> Iterator[None]:
    """
    Context manager that temporarily disables 64-bit floating point precision.

    Convenience equivalent of ``enable_x64(False)``; the previous setting
    is restored when the ``with`` block exits.
    """
    # The body of the caller's ``with`` block runs at this point.
    yield

37

```

38

39

Usage examples:

40

```python

41

import jax.experimental
import jax.numpy as jnp

# Enable double precision for the duration of the `with` block.
# (Calling jax.experimental.enable_x64() without `with` has no effect.)
with jax.experimental.enable_x64():
    x = jnp.array(1.0)  # float64 inside the block instead of float32
    print(x.dtype)  # float64

# Disable double precision for the duration of the `with` block.
with jax.experimental.disable_x64():
    y = jnp.array(1.0)  # float32
    print(y.dtype)  # float32

50

```

51

52

### I/O and Callbacks

53

54

Call impure host-side Python functions (for I/O and other side effects) from within JAX computations, including jitted ones.

55

56

```python { .api }

57

def io_callback(
    callback: Callable[..., Any],
    result_shape_dtypes: Any,
    *args: Any,
    sharding=None,
    ordered: bool = False,
    **kwargs: Any,
) -> Any:
    """
    Call an impure host-side Python function from within a JAX computation.

    Args:
      callback: Host function to call. It is assumed to be impure (it may
        perform I/O or otherwise have side effects); for pure functions,
        ``jax.pure_callback`` may allow more efficient execution.
      result_shape_dtypes: Pytree whose leaves have ``shape`` and ``dtype``
        attributes (``jax.ShapeDtypeStruct`` is typical), matching the
        structure of the value ``callback`` returns at runtime.
      *args: Positional arguments passed to ``callback``.
      sharding: Optional sharding specifying the device from which the
        callback is invoked.
      ordered: Whether sequential calls to ``callback`` must execute in
        program order.
      **kwargs: Keyword arguments passed to ``callback``.

    Returns:
      The callback's result, with the structure, shapes and dtypes given by
      ``result_shape_dtypes``.

    Note:
      Unlike ``jax.pure_callback``, ``io_callback`` takes no
      ``vmap_method`` argument.
    """

81

```

82

83

Usage examples:

84

```python

85

import pickle

import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np


# Logging during computation (debugging)
def log_value(x, step):
    print(f"Step {step}: value = {x}")
    return x


@jax.jit
def training_step(x, step):
    # Log intermediate values during training. The callback's return value
    # must match the ShapeDtypeStruct passed as result_shape_dtypes.
    x = jax.experimental.io_callback(
        log_value,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x, step,
    )
    return x * 2


# File I/O during computation
CHECKPOINT_EVERY = 100


def save_checkpoint(params, step):
    """Host-side callback: pickle ``params`` when ``step`` is a multiple of CHECKPOINT_EVERY."""
    step = int(step)
    # The every-N check runs on the host, so the traced code needs no lax.cond.
    if step % CHECKPOINT_EVERY == 0:
        with open(f'checkpoint_{step}.pkl', 'wb') as f:
            pickle.dump(params, f)
    # Return exactly the declared result type: jax.ShapeDtypeStruct((), jnp.int32).
    return np.int32(step)


@jax.jit
def train_with_checkpointing(params, data, step):
    # compute_loss and update_params are user-supplied placeholders.
    # value_and_grad evaluates the loss once instead of running it twice.
    loss, grads = jax.value_and_grad(compute_loss)(params, data)
    new_params = update_params(params, grads)

    # Save a checkpoint every CHECKPOINT_EVERY steps (decided on the host).
    jax.experimental.io_callback(
        save_checkpoint,
        jax.ShapeDtypeStruct((), jnp.int32),
        new_params, step,
    )

    return new_params, loss

122

```

123

124

### Advanced Differentiation

125

126

Experimental differentiation features and optimizations.

127

128

```python { .api }

129

def saved_input_vjp(f, *primals) -> tuple[Any, Callable]:
    """
    Compute a VJP whose backward function works from saved primal inputs.

    Aimed at memory efficiency: the inputs are reused on the backward pass
    rather than being kept as extra residuals.

    Args:
      f: Function to differentiate.
      *primals: Input values at which ``f`` is evaluated.

    Returns:
      A ``(primal_out, vjp_fun)`` pair, where ``vjp_fun`` maps an output
      cotangent to input cotangents using the saved inputs.

    NOTE(review): this signature looks simplified. Recent JAX releases
    appear to take an extra argument selecting which inputs are saved, and
    the returned function may expect those inputs to be passed back in.
    Confirm against the installed JAX version.
    """


# Short alias for saved_input_vjp
si_vjp = saved_input_vjp

143

```

144

145

Usage example:

146

```python

147

def expensive_function(x, y):

148

# Some expensive computation that we want to differentiate

149

z = jnp.exp(x) + jnp.sin(y)

150

return jnp.sum(z ** 2)

151

152

# Use saved input VJP for memory efficiency

153

x, y = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])

154

primal_out, vjp_fn = jax.experimental.saved_input_vjp(expensive_function, x, y)

155

156

# Compute VJP with cotangent

157

cotangent = 1.0

158

x_grad, y_grad = vjp_fn(cotangent)

159

```

160

161

### Extended Array Types

162

163

Experimental array types and extended functionality.

164

165

```python { .api }

166

class EArray:
    """
    Experimental extended array type.

    May carry metadata and features beyond those of standard JAX arrays.
    """


class MutableArray:
    """
    Experimental array type that supports in-place updates, for specific use cases.

    Warning: Breaks JAX's functional programming model. Use carefully.

    NOTE(review): newer JAX releases may expose this concept under a
    different name; confirm against the installed version.
    """


def mutable_array(init_val) -> MutableArray:
    """
    Build a MutableArray holding ``init_val``.

    Args:
      init_val: Initial array value.

    Returns:
      A MutableArray whose contents can be modified in-place.
    """

193

```

194

195

### Type System Extensions

196

197

Experimental extensions to JAX's type system.

198

199

```python { .api }

200

def primal_tangent_dtype(primal_dtype, tangent_dtype=None):
    """
    Build a dtype pairing a primal dtype with a tangent dtype for forward-mode AD.

    Args:
      primal_dtype: dtype used for primal values.
      tangent_dtype: dtype used for tangent values; when omitted, the
        primal dtype is used.

    Returns:
      A dtype describing the primal/tangent combination.

    NOTE(review): in JAX this appears to create an extended dtype whose
    tangent dtype differs from its primal dtype, and ``tangent_dtype`` may
    be required. Verify against the installed version.
    """

211

```

212

213

### Compilation and Performance

214

215

Experimental compilation features and performance optimizations.

216

217

```python { .api }

218

# NOTE(review): these names do not appear to be part of jax.experimental's
# public API; verify before relying on them. Related documented APIs include
# jax.clear_caches() and the persistent compilation cache configured via
# jax.config.update("jax_compilation_cache_dir", ...).

# Compilation control
def disable_jit_cache() -> None:
    """Turn the JIT compilation cache off (debugging aid)."""


def enable_jit_cache() -> None:
    """Turn the JIT compilation cache back on."""


# Performance monitoring
def compilation_cache_stats() -> dict:
    """Return statistics describing the JIT compilation cache."""


def clear_compilation_cache() -> None:
    """Empty the JIT compilation cache."""

231

```

232

233

### Hardware-Specific Features

234

235

Experimental features for specific hardware accelerators.

236

237

```python { .api }

238

# NOTE(review): these names do not appear in jax.experimental's public API.
# Device memory is usually configured through environment variables such as
# XLA_PYTHON_CLIENT_PREALLOCATE / XLA_PYTHON_CLIENT_MEM_FRACTION, and per-device
# statistics come from jax.devices()[i].memory_stats(). Verify before use.

# TPU-specific features
class TPUMemoryFraction:
    """Controls the fraction of TPU memory in use."""


def set_tpu_memory_fraction(fraction: float) -> None:
    """
    Limit TPU memory usage to a fraction of the device total.

    Args:
      fraction: Memory fraction, between 0.0 and 1.0.
    """


# GPU-specific features
def gpu_memory_stats() -> dict:
    """Report GPU memory usage statistics."""


def set_gpu_memory_growth(enable: bool) -> None:
    """
    Toggle incremental (on-demand) GPU memory allocation.

    Args:
      enable: True to grow allocations incrementally, False otherwise.
    """

261

```

262

263

### Automatic Mixed Precision

264

265

Experimental automatic mixed precision for training acceleration.

266

267

```python { .api }

268

# NOTE(review): no automatic-mixed-precision API appears in jax.experimental;
# mixed precision in JAX is usually done with explicit dtype casts or a
# dedicated library. Verify before relying on these names.
class AutoMixedPrecision:
    """Policy object describing automatic mixed precision for training."""

    def __init__(self, policy='float16'):
        """
        Store the precision policy.

        Args:
          policy: Name of the precision policy ('float16', 'bfloat16', etc.).
        """
        self.policy = policy

    def __call__(self, fn):
        """Wrap ``fn`` so it runs under this AMP policy."""


def amp_policy(policy_name: str) -> AutoMixedPrecision:
    """
    Build an automatic mixed precision policy.

    Args:
      policy_name: Name of the precision policy.

    Returns:
      The corresponding AMP policy object.
    """

294

```

295

296

### Distributed Computing Extensions

297

298

Experimental distributed computing features beyond standard pmap/shard_map.

299

300

```python { .api }

301

# NOTE(review): GlobalDeviceArray was removed from JAX in favor of jax.Array.
# Multi-host arrays are normally built with jax.make_array_from_callback and a
# jax.sharding.NamedSharding(mesh, PartitionSpec(...)). Verify before use.

def multi_host_utils():
    """Helpers for computation distributed over multiple hosts."""


class GlobalDeviceArray:
    """
    Experimental array type for large-scale distributed computation.

    Models an array whose data spans multiple hosts.
    """


def create_global_device_array(
    shape,
    dtype,
    mesh,
    partition_spec,
) -> GlobalDeviceArray:
    """
    Build a global device array spread across a distributed system.

    Args:
      shape: Shape of the full (global) array.
      dtype: Element data type.
      mesh: Device mesh to place the array on.
      partition_spec: How the array is partitioned over the mesh.

    Returns:
      The resulting global device array.
    """

331

```

332

333

### Research and Prototype Features

334

335

Cutting-edge research features that may be highly experimental.

336

337

```python { .api }

338

# NOTE(review): jax.experimental.sparse is a real module (e.g. BCOO arrays),
# and ahead-of-time compilation is available as jax.jit(fn).lower(*args).compile().
# The other names below do not appear to be public JAX APIs; verify before use.

# Sparsity support
class SparseArray:
    """Prototype sparse array type."""


def sparse_ops():
    """Entry point for sparse operations (highly experimental)."""


# Quantization support
def quantized_dot(lhs, rhs, **kwargs):
    """Prototype quantized matrix multiplication of ``lhs`` and ``rhs``."""


def quantization_utils():
    """Helpers for quantized computation."""


# Custom operators
def custom_op_builder():
    """Builder for custom XLA operations."""


# Advanced compilation
def ahead_of_time_compile(fn, *args, **kwargs):
    """Compile ``fn`` ahead of time for the given arguments (experimental)."""

365

```

366

367

### Debugging and Profiling

368

369

Experimental debugging and profiling tools.

370

371

```python { .api }

372

# NOTE(review): the documented debugging callbacks live under jax.debug
# (jax.debug.callback, jax.debug.print), not jax.experimental. Verify
# these names against the installed JAX version before use.

def debug_callback(callback, *args, **kwargs):
    """
    Invoke a debugging callback without affecting the computation graph.

    Args:
      callback: Debug function to invoke.
      args: Positional arguments forwarded to ``callback``.
      kwargs: Keyword arguments forwarded to ``callback``.
    """


def trace_function(fn):
    """
    Wrap ``fn`` so its execution is traced for debugging.

    Args:
      fn: Function to trace.

    Returns:
      The traced version of ``fn``.
    """


def memory_profiler():
    """Helpers for memory profiling."""


def computation_graph_visualizer():
    """Helpers for visualizing computation graphs."""

400

```

401

402

## Migration Patterns

403

404

When an experimental feature graduates to the main JAX API, update its import path:

405

406

```python

407

# Old experimental usage

408

from jax.experimental import feature_name

409

410

# New main API usage (after graduation)

411

from jax import feature_name

412

413

# Or sometimes moves to different module

414

from jax.some_module import feature_name

415

```

416

417

## Usage Guidelines

418

419

### Best Practices for Experimental Features

420

421

```python

422

# 1. Pin the exact JAX version when relying on experimental features
# requirements.txt: jax==0.7.1  # Pin exact version

# 2. Fall back gracefully when an experimental feature is missing
try:
    from jax.experimental import new_feature
except ImportError:
    use_experimental = False
else:
    use_experimental = True


def my_function(x):
    # Take the experimental path only when its import succeeded.
    if not use_experimental:
        return traditional_op(x)
    return new_feature.optimized_op(x)


# 3. Gate experimental code behind an explicit feature flag
USE_EXPERIMENTAL_AMP = False

if USE_EXPERIMENTAL_AMP:
    amp_policy = jax.experimental.amp_policy('float16')
    train_fn = amp_policy(train_fn)


# 4. Record experimental dependencies and the tested version in docstrings
def experimental_model_fn(x):
    """
    Model function built on experimental JAX features.

    Warning: Relies on jax.experimental.* APIs that may change.
    Tested with JAX v0.7.1.
    """
    # Implementation using experimental features goes here

455

```

456

457

### Testing Experimental Features

458

459

```python

460

import re

import jax
import jax.experimental
import pytest

# Skip tests if the experimental feature is not available.
# `jax` must be imported above: the skipif condition is evaluated when the
# decorator runs, i.e. at module import time.
@pytest.mark.skipif(
    not hasattr(jax.experimental, 'new_feature'),
    reason="Experimental feature not available"
)
def test_experimental_feature():
    # Test experimental functionality
    pass


# Conditional testing based on JAX version.
# Keep only the numeric runs so versions like "0.7.1.dev20250101" or
# "0.8.0rc1" parse; take (major, minor).
jax_version = tuple(int(part) for part in re.findall(r"\d+", jax.__version__)[:2])

@pytest.mark.skipif(
    jax_version < (0, 7),
    reason="Feature requires JAX >= 0.7"
)
def test_version_dependent_feature():
    # Test version-dependent experimental feature
    pass

482

```