# Core Program Transformations

JAX's core strength lies in its composable function transformations, which enable automatic differentiation, just-in-time compilation, vectorization, and parallelization. These transformations can be composed arbitrarily and applied to pure Python functions.

## Capabilities

### Just-in-Time Compilation

Compiles functions to optimized XLA code for improved performance on CPUs, GPUs, and TPUs. Compilation happens lazily on the first call. The compiled executable is cached and reused for later calls whose arguments have the same shapes, dtypes, and static argument values.

```python { .api }
def jit(
    fun: Callable,
    in_shardings=None,
    out_shardings=None,
    static_argnums=None,
    static_argnames=None,
    donate_argnums=None,
    donate_argnames=None,
    keep_unused=False,
    device=None,
    backend=None,
    inline=False,
    abstracted_axes=None
) -> Callable:
    """
    Just-in-time compile a function for improved performance.

    Args:
        fun: Function to JIT compile
        in_shardings: How inputs should be sharded across devices
        out_shardings: How outputs should be sharded across devices
        static_argnums: Indices of arguments treated as static (compile-time
            constants); a new static value triggers recompilation
        static_argnames: Names of arguments treated as static
        donate_argnums: Indices of arguments whose buffers may be donated
            (reused for outputs)
        donate_argnames: Names of arguments to donate
        keep_unused: Whether to keep unused arguments in the compiled function
            instead of pruning them
        device: Device to place the computation on
        backend: Backend to use for compilation
        inline: Whether to inline this function into enclosing traced
            computations instead of emitting a separate call
        abstracted_axes: Axes to abstract for shape polymorphism

    Returns:
        JIT-compiled function with the same signature as fun
    """
```

Usage example:

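```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def fast_computation(x, y):
    return jnp.sum(x ** 2 + y ** 2)

# Or with static arguments: `size` must be a Python int known at trace time
# because it determines the output shape; each new value triggers a recompile.
@partial(jax.jit, static_argnums=(1,))
def take_prefix(x, size):
    return x[:size]
```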
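
A quick way to observe the compile-once, reuse-later behavior is to count how often the function body is traced. This is a minimal sketch that relies on Python side effects running only while JAX traces the function; `trace_count` and `scaled_sum` are illustrative names, not JAX APIs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, not a JAX API

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(2.0 * x)

scaled_sum(jnp.ones(3))   # first call: trace and compile
scaled_sum(jnp.zeros(3))  # same shape and dtype: cached executable is reused
scaled_sum(jnp.ones(5))   # new shape: trace and compile again
print(trace_count)        # 2
```

Because the counter only moves during tracing, it is a cheap way to spot unexpected recompilation, for example when inputs keep changing shape.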

### Automatic Differentiation

Compute gradients of scalar-valued functions using reverse-mode automatic differentiation (backpropagation).

```python { .api }
def grad(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence = ()
) -> Callable:
    """
    Create a function that computes the gradient of a scalar-valued function.

    Args:
        fun: Function to differentiate (must return a scalar)
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether fun returns a pair (value, aux); only value is
            differentiated
        holomorphic: Whether fun is holomorphic (complex differentiable)
        allow_int: Whether to allow differentiating with respect to
            integer-valued inputs
        reduce_axes: Named axes over which the backward pass sums gradients
            (legacy option, deprecated in recent JAX releases)

    Returns:
        Function that computes the gradient with respect to the specified arguments
    """

def value_and_grad(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence = ()
) -> Callable:
    """
    Create a function that computes both the value and the gradient.

    Args:
        fun: Function to differentiate (must return a scalar)
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether fun returns a pair (value, aux)
        holomorphic: Whether fun is holomorphic
        allow_int: Whether to allow differentiating with respect to
            integer-valued inputs
        reduce_axes: Named axes over which the backward pass sums gradients
            (legacy option, deprecated in recent JAX releases)

    Returns:
        Function that returns a (value, gradient) tuple, or
        ((value, aux), gradient) when has_aux=True
    """
```

Usage examples:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = params[0] * x + params[1]
    return jnp.mean((predictions - y) ** 2)

# Example data: slope/intercept parameters and a tiny dataset
params = jnp.array([1.0, 0.0])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([3.0, 5.0, 7.0])

# Gradient function (differentiates w.r.t. the first argument by default)
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x, y)

# Value and gradient together
val_grad_fn = jax.value_and_grad(loss_fn)
loss_val, grads = val_grad_fn(params, x, y)

# Gradient with respect to multiple arguments
multi_grad_fn = jax.grad(loss_fn, argnums=(0, 1, 2))
param_grads, x_grads, y_grads = multi_grad_fn(params, x, y)
```
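
When the loss also needs to report intermediate values, `has_aux=True` lets the function return a `(value, aux)` pair in which only `value` is differentiated. A sketch assuming the same linear model as the example above; `loss_with_aux` and the `"predictions"` key are hypothetical names:

```python
import jax
import jax.numpy as jnp

def loss_with_aux(params, x, y):
    predictions = params[0] * x + params[1]
    loss = jnp.mean((predictions - y) ** 2)
    # The second element is passed through untouched, not differentiated
    return loss, {"predictions": predictions}

params = jnp.array([1.0, 0.0])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([3.0, 5.0, 7.0])

# With has_aux=True the result is ((value, aux), gradient)
(loss_val, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params, x, y)
```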

### Jacobian Computation

Compute full Jacobian matrices using forward-mode or reverse-mode differentiation.

```python { .api }
def jacobian(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False
) -> Callable:
    """
    Create a function that computes the Jacobian matrix.

    jax.jacobian is an alias of jax.jacrev (reverse-mode AD).

    Args:
        fun: Function to compute the Jacobian of
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether fun returns auxiliary data
        holomorphic: Whether fun is holomorphic
        allow_int: Whether to allow integer inputs

    Returns:
        Function that returns the Jacobian matrix
    """

def jacfwd(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False
) -> Callable:
    """
    Jacobian using forward-mode AD (efficient for tall Jacobians, i.e.
    more outputs than inputs).

    Args:
        fun: Function to differentiate
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether fun returns auxiliary data
        holomorphic: Whether fun is holomorphic

    Returns:
        Function that computes the Jacobian using forward-mode AD
    """

def jacrev(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False
) -> Callable:
    """
    Jacobian using reverse-mode AD (efficient for wide Jacobians, i.e.
    more inputs than outputs).

    Args:
        fun: Function to differentiate
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether fun returns auxiliary data
        holomorphic: Whether fun is holomorphic
        allow_int: Whether to allow integer inputs

    Returns:
        Function that computes the Jacobian using reverse-mode AD
    """

def hessian(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False
) -> Callable:
    """
    Create a function that computes the Hessian matrix (second derivatives),
    implemented as jacfwd(jacrev(fun)).

    Args:
        fun: Scalar-valued function to compute the Hessian of
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether fun returns auxiliary data
        holomorphic: Whether fun is holomorphic

    Returns:
        Function that returns the Hessian matrix
    """
```
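
The listing above has no usage example, so here is a small sketch comparing the Jacobian helpers. The functions `f` and `g` are made up for illustration: for `f` (3 inputs, 2 outputs) forward and reverse mode should produce the same matrix, and the Hessian of a sum of cubes should be the diagonal matrix `diag(6 * x)`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Maps R^3 -> R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

def g(x):
    # Scalar-valued, so its Hessian is defined
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 0.0])

J_fwd = jax.jacfwd(f)(x)  # shape (2, 3), built from one JVP per input
J_rev = jax.jacrev(f)(x)  # shape (2, 3), built from one VJP per output
H = jax.hessian(g)(x)     # shape (3, 3), equals diag(6 * x)
```

Both Jacobians are exact; the choice between `jacfwd` and `jacrev` only affects cost, which depends on whether the function has more inputs or more outputs.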

### Forward and Reverse Mode Primitives

Lower-level differentiation primitives for building custom transformations.

```python { .api }
def jvp(
    fun: Callable,
    primals: Sequence,
    tangents: Sequence
) -> tuple:
    """
    Jacobian-vector product using forward-mode AD.

    Args:
        fun: Function to differentiate
        primals: Tuple or list of values at which to evaluate fun
        tangents: Tangent vectors to multiply the Jacobian by, with the
            same structure and shapes as primals

    Returns:
        Tuple of (primals_out, tangents_out)
    """

def vjp(
    fun: Callable,
    *primals
) -> tuple:
    """
    Vector-Jacobian product using reverse-mode AD.

    Args:
        fun: Function to differentiate
        primals: Values at which to evaluate fun (passed positionally)

    Returns:
        Tuple of (primals_out, vjp_fun), where vjp_fun maps a cotangent for
        primals_out to a tuple of cotangents, one per primal
    """

def linearize(fun: Callable, *primals) -> tuple:
    """
    Linearize a function around a given point.

    Args:
        fun: Function to linearize
        primals: Point to linearize around

    Returns:
        Tuple of (primals_out, jvp_fun), where jvp_fun computes JVPs at that
        point without re-evaluating fun
    """
```
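
As a rough illustration of how the three primitives relate, the sketch below applies them to an elementwise function. Its Jacobian is diagonal, so the JVP and the VJP with a vector of ones coincide and both equal the derivative `cos(x) * x + sin(x)`. The function `f` is an arbitrary example.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])
v = jnp.ones_like(x)

# Forward mode: J @ v, computed together with f(x)
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode: v^T @ J via the returned pullback (one cotangent per primal)
y2, f_vjp = jax.vjp(f, x)
(vj,) = f_vjp(v)

# linearize: evaluate once, then reuse the linear map for many tangents
y3, f_jvp = jax.linearize(f, x)
jv2 = f_jvp(v)
```

`linearize` pays off when many tangents are pushed through the same point, since the primal computation is not repeated.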

### Vectorization

Transform a function written for a single example into one that operates on a batch, by mapping over a batch dimension of its inputs without explicit Python loops.

```python { .api }
def vmap(
    fun: Callable,
    in_axes=0,
    out_axes=0,
    axis_name=None,
    axis_size=None,
    spmd_axis_name=None
) -> Callable:
    """
    Vectorizing map that adds a batch dimension to a function.

    Args:
        fun: Function to vectorize
        in_axes: Which axis of each input to map over (int, None, or tuple)
        out_axes: Where the mapped axis appears in each output (int, None, or tuple)
        axis_name: Name for the mapped axis (for use with collectives such as psum)
        axis_size: Size of the mapped axis; only needed when it cannot be
            inferred from the inputs
        spmd_axis_name: SPMD axis name for multi-device computation

    Returns:
        Vectorized function that works on batches
    """
```

Usage examples:

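```python
import jax
import jax.numpy as jnp

def single_example_fn(v):
    return jnp.dot(v, v)  # written for one example of shape (features,)

def process_fn(x, y):
    return x * y  # x and y are single examples of shape (features,)

# Vectorize over the first axis of the input
batch_inputs = jnp.ones((8, 3))
batch_fn = jax.vmap(single_example_fn)
batch_outputs = batch_fn(batch_inputs)  # shape (8,)

# Vectorize with different input axes
# x has batch dim 0, y has batch dim 1
fn = jax.vmap(process_fn, in_axes=(0, 1))
out = fn(jnp.ones((8, 3)), jnp.ones((3, 8)))  # shape (8, 3)

# Vectorize with no batch dim for some inputs
# x has batch dim 0, y is broadcast to all batch elements
fn = jax.vmap(process_fn, in_axes=(0, None))
out = fn(jnp.ones((8, 3)), jnp.ones(3))  # shape (8, 3)
```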
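
The `axis_name` argument makes the batch axis visible to collective operations inside the mapped function. A minimal sketch, assuming the goal is to divide each row by the column-wise sum over the batch; `normalize` and the axis name `"batch"` are illustrative choices:

```python
import jax
import jax.numpy as jnp

def normalize(row):
    # psum over the named vmap axis sums this value across all batch elements
    total = jax.lax.psum(row, axis_name="batch")
    return row / total

xs = jnp.array([[1.0, 2.0], [3.0, 6.0]])
out = jax.vmap(normalize, axis_name="batch")(xs)
# out == [[0.25, 0.25], [0.75, 0.75]]; each column now sums to 1
```

The same function body can later be reused under `jax.pmap` with the same axis name, where the collective runs across devices instead of across a vectorized batch.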

### Parallelization

Distribute computation across multiple devices using SPMD (Single Program, Multiple Data) parallelism.

```python { .api }
def pmap(
    fun: Callable,
    axis_name=None,
    in_axes=0,
    out_axes=0,
    static_broadcasted_argnums=(),
    devices=None,
    backend=None,
    axis_size=None,
    donate_argnums=(),
    global_arg_shapes=None
) -> Callable:
    """
    Parallel map that distributes computation across multiple devices.

    Args:
        fun: Function to parallelize
        axis_name: Name for the parallel axis (for use with collectives)
        in_axes: Which axis of each input to split across devices
        out_axes: Where the device axis appears in each output
        static_broadcasted_argnums: Indices of arguments treated as static
            and broadcast to all devices
        devices: Explicit device placement
        backend: Backend to use
        axis_size: Size of the parallel axis
        donate_argnums: Indices of arguments whose buffers may be donated
        global_arg_shapes: Global shapes for arguments

    Returns:
        Function that runs in parallel across devices
    """
```

Usage example:

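```python
import jax
import jax.numpy as jnp

def single_device_fn(x):
    return jnp.sum(x ** 2, axis=-1)  # sees only one device's slice

# Function runs on each device with its slice of data
parallel_fn = jax.pmap(single_device_fn)
# Input shape: (num_devices, per_device_batch_size, ...)
distributed_inputs = jnp.ones((jax.local_device_count(), 4, 3))
outputs = parallel_fn(distributed_inputs)  # shape (num_devices, 4)
```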
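
Naming the parallel axis lets per-device code call collectives such as `jax.lax.pmean`. A minimal sketch that also runs on a single-device CPU setup, where the "mean across devices" is just the single row; `mean_across_devices` and the axis name `"devices"` are illustrative:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)  # one row per device

def mean_across_devices(shard):
    # pmean averages `shard` over every device participating in the pmap
    return jax.lax.pmean(shard, axis_name="devices")

out = jax.pmap(mean_across_devices, axis_name="devices")(x)
# Every row of `out` holds the mean of the rows of `x`
```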

### Memory-Efficient Gradient Computation

Trade computation for memory using gradient checkpointing (rematerialization).

```python { .api }
def checkpoint(
    fun: Callable,
    *,
    concrete: bool = False,
    policy: Callable = None,
    prevent_cse: bool = True,
    static_argnums: int | Sequence[int] = ()
) -> Callable:
    """
    Gradient checkpointing for memory-efficient backpropagation.

    Instead of storing fun's intermediate values for the backward pass,
    recompute them when gradients are taken.

    Args:
        fun: Function to apply checkpointing to
        concrete: Legacy option allowing value-dependent Python control flow;
            prefer static_argnums
        policy: Callable deciding which intermediates may be saved instead of
            recomputed (see jax.checkpoint_policies)
        prevent_cse: Whether to prevent common subexpression elimination from
            undoing the recomputation
        static_argnums: Indices of arguments to treat as static

    Returns:
        Checkpointed function that saves memory during the backward pass
    """

# Alias for checkpoint
remat = checkpoint
```

Usage example:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint
def expensive_layer(x, params):
    # Intermediates are not stored; they are recomputed during backprop
    return jnp.tanh(x @ params)

x = jnp.ones((4, 3))
params = jnp.ones((3, 3)) * 0.1

# Use in gradient computation to save memory (sum stands in for a real loss)
grad_fn = jax.grad(lambda params: jnp.sum(expensive_layer(x, params)))
grads = grad_fn(params)
```
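
A `policy` can keep selected intermediates instead of recomputing everything. This sketch assumes the built-in `jax.checkpoint_policies.dots_with_no_batch_dims_saveable` policy, which saves matrix-multiply results and recomputes the rest. Checkpointing changes memory use, not gradient values, so both versions below should agree; `layer` is an illustrative function.

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w)

# Save matmul outputs that have no batch dimensions; recompute the rest (tanh)
remat_layer = jax.checkpoint(
    layer, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

x = jnp.ones((4, 3))
w = jnp.full((3, 3), 0.1)

g_plain = jax.grad(lambda w: jnp.sum(layer(x, w) ** 2))(w)
g_remat = jax.grad(lambda w: jnp.sum(remat_layer(x, w) ** 2))(w)
```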

### Custom Derivatives

Define custom differentiation rules for functions, for example to improve numerical stability or efficiency.

```python { .api }
def custom_gradient(fun: Callable) -> Callable:
    """
    Decorator to define a custom gradient for a function.

    The decorated function should return (primal_out, vjp_fn), where
    vjp_fn maps the output cotangent to a tuple of input cotangents
    (one per positional argument).

    Args:
        fun: Function with custom gradient implementation

    Returns:
        Function with custom gradient behavior
    """

def custom_jvp(fun: Callable, nondiff_argnums: tuple[int, ...] = ()) -> Callable:
    """
    Decorator to define a custom JVP (forward-mode derivative) rule.

    Register the rule with the decorated function's defjvp method. JAX
    derives reverse-mode differentiation from it, provided the rule is
    linear in the tangents.

    Args:
        fun: Function to define custom JVP for
        nondiff_argnums: Indices of arguments that are not differentiated

    Returns:
        Function with custom JVP behavior
    """

def custom_vjp(fun: Callable, nondiff_argnums: tuple[int, ...] = ()) -> Callable:
    """
    Decorator to define a custom VJP (reverse-mode derivative) rule.

    Register forward and backward functions with the decorated function's
    defvjp method. Forward-mode differentiation is not supported.

    Args:
        fun: Function to define custom VJP for
        nondiff_argnums: Indices of arguments that are not differentiated

    Returns:
        Function with custom VJP behavior
    """
```
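
The decorators above are normally used by attaching the rule after the function is defined. The sketch below is modeled on the examples in the JAX custom-derivatives tutorial: `log1pexp` gets a numerically stable JVP, `clip_gradient` gets a VJP that clips incoming cotangents, and `square_half_grad` uses `custom_gradient` with a deliberately altered gradient so the effect is visible. All three names are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = log1pexp(x)
    # d/dx log(1 + e^x) = sigmoid(x), written so large x gives 1.0 instead of nan
    return y, x_dot * (1.0 - 1.0 / (1.0 + jnp.exp(x)))

@jax.custom_vjp
def clip_gradient(x):
    return x  # identity on the forward pass

def clip_fwd(x):
    return x, None  # no residuals needed for the backward pass

def clip_bwd(_, g):
    # Clip the incoming cotangent; return one entry per primal input
    return (jnp.clip(g, -1.0, 1.0),)

clip_gradient.defvjp(clip_fwd, clip_bwd)

@jax.custom_gradient
def square_half_grad(x):
    # Returns x**2 but reports g * x (not 2 * g * x) as the gradient
    return x ** 2, lambda g: (g * x,)
```

Without the custom rule, `jax.grad(log1pexp)(100.0)` evaluates to `nan` in float32 because `exp(100.0)` overflows; with it, the gradient is 1.0.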

### Advanced Differentiation

Additional differentiation utilities, plus helpers for host callbacks, side-effect synchronization, and naming computations for debugging and profiling.

```python { .api }
def stop_gradient(x) -> Array:
    """
    Stop gradient computation at this point (available as jax.lax.stop_gradient).

    Args:
        x: Array (or pytree of arrays) to stop gradients for

    Returns:
        x with its value unchanged; derivatives with respect to it are
        treated as zero
    """

def fwd_and_bwd(
    fun: Callable,
    *primals,
    **kwargs
) -> tuple:
    """
    Compute forward and backward passes separately.

    Args:
        fun: Function to compute forward/backward for
        primals: Input values

    Returns:
        Tuple of (primal_out, vjp_fun)
    """

def closure_convert(
    fun: Callable,
    *example_args
) -> tuple:
    """
    Hoist values closed over by fun into explicit arguments, for use with
    higher-order custom derivatives.

    Args:
        fun: Function that may close over traced values
        example_args: Example inputs to fun, used to trace it

    Returns:
        Tuple of (converted_fun, hoisted_consts), where converted_fun takes
        fun's arguments followed by the hoisted closure values
    """

def pure_callback(
    callback: Callable,
    result_shape_dtypes,
    *args,
    sharding=None,
    vmap_method=None,
    **kwargs
) -> Any:
    """
    Call a pure (side-effect-free) Python function on the host from within
    a JAX computation. JAX may skip or repeat the call, so the callback must
    not rely on side effects.

    Args:
        callback: Pure host function to call; it receives NumPy arrays
        result_shape_dtypes: Shape and dtype of the callback result
            (e.g. jax.ShapeDtypeStruct)
        args: Arguments to pass to callback
        sharding: Sharding specification for the result
        vmap_method: How the callback behaves under vmap
        kwargs: Additional keyword arguments passed to callback

    Returns:
        Result of callback with the specified shape and dtype
    """

def effects_barrier() -> None:
    """
    Create a synchronization barrier for side effects.

    Ensures all preceding computations with side effects complete
    before continuing with subsequent computations.
    """

def named_call(f: Callable, *, name: str) -> Callable:
    """
    Wrap a function with a name for debugging and profiling.

    Args:
        f: Function to wrap
        name: Name to associate with function calls

    Returns:
        Wrapped function that appears with the given name in traces
    """

def named_scope(name: str):
    """
    Context manager for named scopes in JAX computations.

    Args:
        name: Name for the computation scope

    Usage:
        with jax.named_scope("layer1"):
            output = layer_computation(input)
    """
```
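
Two of these utilities are easiest to understand by example. The sketch below shows `jax.lax.stop_gradient` removing one term from a derivative, and `jax.pure_callback` calling NumPy from inside a jitted function. The functions `f`, `host_sin`, and `g` are illustrative, and the callback is assumed to return exactly the declared shape and dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    # The second term contributes to the value but is a constant for autodiff
    return x ** 2 + jax.lax.stop_gradient(x ** 2)

grad_at_3 = jax.grad(f)(3.0)  # 6.0 rather than 12.0

def host_sin(x):
    # Runs as ordinary NumPy on the host; must match the declared shape/dtype
    return np.sin(x).astype(x.dtype)

@jax.jit
def g(x):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)

y = g(jnp.linspace(0.0, 1.0, 4))
```

A host callback cannot be differentiated automatically; if gradients through it are needed, it would have to be wrapped in `jax.custom_jvp` or `jax.custom_vjp`.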

## Transformation Composition

JAX transformations can be composed arbitrarily. Some common combinations:

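```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    predictions = params[0] * x + params[1]
    return jnp.mean((predictions - y) ** 2)

params = jnp.array([1.0, 0.0])
x, y = jnp.array([1.0, 2.0, 3.0]), jnp.array([3.0, 5.0, 7.0])

# JIT-compiled gradient
fast_grad = jax.jit(jax.grad(loss_fn))

# Vectorized gradient (per-example gradients): params shared, x and y mapped
batch_grad = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

# Parallel gradient computation (every input needs a leading device axis)
parallel_grad = jax.pmap(jax.grad(loss_fn))

# Second derivatives (Hessian-vector product, forward-over-reverse)
grad_params = lambda p: jax.grad(loss_fn)(p, x, y)
hvp = lambda v: jax.jvp(grad_params, (params,), (v,))[1]

# Gradient of gradient (for meta-learning): differentiate through an SGD step
update_fn = lambda lr: params - lr * jax.grad(loss_fn)(params, x, y)
meta_grad = jax.grad(lambda lr: loss_fn(update_fn(lr), x, y))
```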