0
# Core Program Transformations
1
2
JAX's core strength lies in its composable function transformations that enable automatic differentiation, just-in-time compilation, vectorization, and parallelization. These transformations can be arbitrarily composed and applied to pure Python functions.
3
4
## Capabilities
5
6
### Just-in-Time Compilation
7
8
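Compiles functions to optimized XLA code for improved performance on CPUs, GPUs, and TPUs. Compilation happens lazily on the first call. The compiled executable is cached and reused for later calls with the same argument shapes, dtypes, and static-argument values; any change to these triggers recompilation.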
9
10
```python { .api }
11
def jit(
12
fun: Callable,
13
in_shardings=None,
14
out_shardings=None,
15
static_argnums=None,
16
static_argnames=None,
17
donate_argnums=None,
18
donate_argnames=None,
19
keep_unused=False,
20
device=None,
21
backend=None,
22
inline=False,
23
abstracted_axes=None
24
) -> Callable:
25
"""
26
Just-in-time compile a function for improved performance.
27
28
Args:
29
fun: Function to JIT compile
30
in_shardings: How inputs should be sharded across devices
31
out_shardings: How outputs should be sharded across devices
32
static_argnums: Tuple of argument indices to treat as static
33
static_argnames: Tuple of keyword argument names to treat as static
34
donate_argnums: Tuple of argument indices to donate (reuse memory)
35
donate_argnames: Tuple of keyword argument names to donate
36
keep_unused: Whether to keep unused arguments in compiled function
37
device: Device to place computation on
38
backend: Backend to use for compilation
39
inline: Whether to inline the function
40
abstracted_axes: Axes to abstract for shape polymorphism
41
42
Returns:
43
JIT-compiled function with same signature as input
44
"""
45
```
46
47
Usage example:
48
```python
49
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def fast_computation(x, y):
    return jnp.sum(x ** 2 + y ** 2)


# Or with static arguments. jax.jit takes the function as its first argument,
# so bind the keyword options with functools.partial to use it as a decorator.
# `size` must be static because it determines the output shape.
@partial(jax.jit, static_argnums=(1,))
def dynamic_slice(x, size):
    return x[:size]
57
```
58
59
### Automatic Differentiation
60
61
Compute gradients of scalar-valued functions using reverse-mode automatic differentiation (backpropagation).
62
63
```python { .api }
64
def grad(
65
fun: Callable,
66
argnums: int | Sequence[int] = 0,
67
has_aux: bool = False,
68
holomorphic: bool = False,
69
allow_int: bool = False,
70
reduce_axes: Sequence[int] = ()
71
) -> Callable:
72
"""
73
Create function that computes gradient of scalar-valued function.
74
75
Args:
76
fun: Function to differentiate (must return scalar)
77
argnums: Argument number(s) to differentiate with respect to
78
has_aux: Whether function returns auxiliary data (value, aux)
79
holomorphic: Whether function is holomorphic (complex differentiable)
80
allow_int: Whether to allow integer inputs
81
reduce_axes: Axes to reduce over when function output is not scalar
82
83
Returns:
84
Function that computes gradient with respect to specified arguments
85
"""
86
87
def value_and_grad(
88
fun: Callable,
89
argnums: int | Sequence[int] = 0,
90
has_aux: bool = False,
91
holomorphic: bool = False,
92
allow_int: bool = False,
93
reduce_axes: Sequence[int] = ()
94
) -> Callable:
95
"""
96
Create function that computes both value and gradient.
97
98
Args:
99
fun: Function to differentiate
100
argnums: Argument number(s) to differentiate with respect to
101
has_aux: Whether function returns auxiliary data
102
holomorphic: Whether function is holomorphic
103
allow_int: Whether to allow integer inputs
104
reduce_axes: Axes to reduce over when function output is not scalar
105
106
Returns:
107
Function that returns (value, gradient) tuple
108
"""
109
```
110
111
Usage examples:
112
```python
113
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    predictions = params[0] * x + params[1]
    return jnp.mean((predictions - y) ** 2)


params = jnp.array([1.0, 0.0])  # [slope, intercept]
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# Gradient function (differentiates w.r.t. the first argument by default)
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x, y)  # same shape as params

# Value and gradient together
val_grad_fn = jax.value_and_grad(loss_fn)
loss_val, grads = val_grad_fn(params, x, y)

# Gradient with respect to multiple arguments: one gradient per index
multi_grad_fn = jax.grad(loss_fn, argnums=(0, 1, 2))
param_grads, x_grads, y_grads = multi_grad_fn(params, x, y)
128
```
129
130
### Jacobian Computation
131
132
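Compute full Jacobian matrices using forward-mode or reverse-mode differentiation. Forward mode (`jacfwd`) is cheaper when the function has more outputs than inputs (a "tall" Jacobian). Reverse mode (`jacrev`, also available as `jacobian`) is cheaper when it has more inputs than outputs (a "wide" Jacobian).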
133
134
```python { .api }
135
def jacobian(
136
fun: Callable,
137
argnums: int | Sequence[int] = 0,
138
has_aux: bool = False,
139
holomorphic: bool = False,
140
allow_int: bool = False
141
) -> Callable:
142
"""
143
Create function that computes Jacobian matrix.
144
145
Args:
146
fun: Function to compute Jacobian of
147
argnums: Argument number(s) to differentiate with respect to
148
has_aux: Whether function returns auxiliary data
149
holomorphic: Whether function is holomorphic
150
allow_int: Whether to allow integer inputs
151
152
Returns:
153
Function that returns Jacobian matrix
154
"""
155
156
def jacfwd(
157
fun: Callable,
158
argnums: int | Sequence[int] = 0,
159
has_aux: bool = False,
160
holomorphic: bool = False
161
) -> Callable:
162
"""
163
Jacobian using forward-mode AD (efficient for tall Jacobians).
164
165
Args:
166
fun: Function to differentiate
167
argnums: Argument number(s) to differentiate with respect to
168
has_aux: Whether function returns auxiliary data
169
holomorphic: Whether function is holomorphic
170
171
Returns:
172
Function that computes Jacobian using forward-mode AD
173
"""
174
175
def jacrev(
176
fun: Callable,
177
argnums: int | Sequence[int] = 0,
178
has_aux: bool = False,
179
holomorphic: bool = False
180
) -> Callable:
181
"""
182
Jacobian using reverse-mode AD (efficient for wide Jacobians).
183
184
Args:
185
fun: Function to differentiate
186
argnums: Argument number(s) to differentiate with respect to
187
has_aux: Whether function returns auxiliary data
188
holomorphic: Whether function is holomorphic
189
190
Returns:
191
Function that computes Jacobian using reverse-mode AD
192
"""
193
194
def hessian(
195
fun: Callable,
196
argnums: int | Sequence[int] = 0,
197
has_aux: bool = False,
198
holomorphic: bool = False
199
) -> Callable:
200
"""
201
Create function that computes Hessian matrix (second derivatives).
202
203
Args:
204
fun: Scalar-valued function to compute Hessian of
205
argnums: Argument number(s) to differentiate with respect to
206
has_aux: Whether function returns auxiliary data
207
holomorphic: Whether function is holomorphic
208
209
Returns:
210
Function that returns Hessian matrix
211
"""
212
```
213
214
### Forward and Reverse Mode Primitives
215
216
Lower-level differentiation primitives for building custom transformations.
217
218
```python { .api }
219
def jvp(fun: Callable, primals: Sequence, tangents: Sequence) -> tuple:
    """
    Evaluate ``fun`` and its Jacobian-vector product with forward-mode AD.

    Args:
        fun: The function to differentiate.
        primals: The point at which ``fun`` is evaluated.
        tangents: The tangent vectors the Jacobian is multiplied by.

    Returns:
        A pair ``(primals_out, tangents_out)``.
    """
235
236
def vjp(fun: Callable, *primals) -> tuple:
    """
    Evaluate ``fun`` and set up a vector-Jacobian product (reverse-mode AD).

    Args:
        fun: The function to differentiate.
        primals: The point at which ``fun`` is evaluated.

    Returns:
        A pair ``(primals_out, vjp_fun)``; calling ``vjp_fun`` computes the VJP.
    """
250
251
def linearize(
    fun: Callable,
    *primals,
) -> tuple:
    """
    Form the linear approximation of ``fun`` at a given point.

    Args:
        fun: The function to linearize.
        primals: The point around which ``fun`` is linearized.

    Returns:
        A pair ``(primals_out, jvp_fun)``; ``jvp_fun`` evaluates JVPs at that point.
    """
262
```
263
264
### Vectorization
265
266
Transform a function written for a single example into one that processes a whole batch. It maps the function over a chosen axis of each input and stacks the results along a batch axis of the outputs.
267
268
```python { .api }
269
def vmap(
    fun: Callable,
    in_axes=0,
    out_axes=0,
    axis_name=None,
    axis_size=None,
    spmd_axis_name=None,
) -> Callable:
    """
    Vectorizing map: turn a per-example function into a batched one.

    Args:
        fun: Function to vectorize.
        in_axes: Which axis of each input to map over. Accepts an int,
            ``None`` (do not map; the input is broadcast to every batch
            element), or a tuple/pytree matching the positional arguments.
        out_axes: Where the mapped axis appears in each output. Accepts an
            int, ``None`` (the output must not vary along the mapped axis),
            or a pytree matching the outputs.
        axis_name: Name for the mapped axis, so collectives such as
            ``jax.lax.psum`` inside ``fun`` can refer to it.
        axis_size: Size of the mapped axis. Normally inferred from the mapped
            inputs; it must be given when no input is mapped (all ``in_axes``
            are ``None``).
        spmd_axis_name: Mesh axis name(s) over which to shard the mapped axis
            in multi-device (SPMD) computation.

    Returns:
        Batched version of ``fun``. All mapped inputs must have the same size
        along their mapped axis.
    """
291
```
292
293
Usage examples:
294
```python
295
import jax
import jax.numpy as jnp


def single_example_fn(x):
    # Operates on a single example of shape (features,).
    return jnp.sum(x ** 2)


def process_fn(x, y):
    # Operates on a single pair of vectors of shape (features,).
    return jnp.dot(x, y)


# Vectorize over the first axis of the input
batch_inputs = jnp.ones((8, 3))
batch_fn = jax.vmap(single_example_fn)
batch_outputs = batch_fn(batch_inputs)  # shape (8,)

# Vectorize with different input axes
# x has batch dim 0, y has batch dim 1
xs = jnp.ones((8, 3))
ys = jnp.ones((3, 8))
fn = jax.vmap(process_fn, in_axes=(0, 1))
mixed_outputs = fn(xs, ys)  # shape (8,)

# Vectorize with no batch dim for some inputs
# x has batch dim 0, y is broadcast to all batch elements
fn = jax.vmap(process_fn, in_axes=(0, None))
broadcast_outputs = fn(xs, jnp.ones(3))  # shape (8,)
306
```
307
308
### Parallelization
309
310
Distribute computation across multiple devices using SPMD (Single Program, Multiple Data) parallelism: the same program runs on every device, each on its own slice of the mapped input axis.
311
312
```python { .api }
313
def pmap(
    fun: Callable,
    axis_name=None,
    in_axes=0,
    out_axes=0,
    static_broadcasted_argnums=(),
    devices=None,
    backend=None,
    axis_size=None,
    donate_argnums=(),
    global_arg_shapes=None,
) -> Callable:
    """
    Map ``fun`` over an array axis, running one slice per device in parallel.

    Args:
        fun: The per-device function.
        axis_name: Name given to the parallel (device) axis.
        in_axes: Axis of each input that is split across devices.
        out_axes: Axis along which per-device outputs are collected.
        static_broadcasted_argnums: Arguments broadcast unchanged to every device.
        devices: Explicit devices to place the computation on.
        backend: Backend to run on.
        axis_size: Size of the parallel axis.
        donate_argnums: Arguments whose memory may be donated.
        global_arg_shapes: Global shapes of the arguments.

    Returns:
        A function that runs ``fun`` on every device in parallel.
    """
343
```
344
345
Usage example:
346
```python
347
import jax
import jax.numpy as jnp


def single_device_fn(x):
    # Each device receives one slice of shape (per_device_batch_size, features).
    return jnp.mean(x, axis=0)


# Input shape: (num_devices, per_device_batch_size, ...). The leading axis is
# split across devices, so its size must not exceed the available devices.
num_devices = jax.local_device_count()
distributed_inputs = jnp.ones((num_devices, 4, 3))

# Function runs on each device with its slice of data
parallel_fn = jax.pmap(single_device_fn)
outputs = parallel_fn(distributed_inputs)  # shape (num_devices, 3)
351
```
352
353
### Memory-Efficient Gradient Computation
354
355
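Trade computation for memory using gradient checkpointing (rematerialization): intermediate values are recomputed during the backward pass instead of being stored from the forward pass.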
356
357
```python { .api }
358
def checkpoint(
359
fun: Callable,
360
*,
361
concrete: bool = False,
362
policy: Callable = None,
363
prevent_cse: bool = True,
364
static_argnums: int | Sequence[int] = ()
365
) -> Callable:
366
"""
367
Gradient checkpointing for memory-efficient backpropagation.
368
369
Args:
370
fun: Function to apply checkpointing to
371
concrete: Whether to use concrete checkpointing
372
policy: Policy for deciding what to checkpoint
373
prevent_cse: Whether to prevent common subexpression elimination
374
static_argnums: Arguments to treat as static
375
376
Returns:
377
Checkpointed function that saves memory during backward pass
378
"""
379
380
# Alias for checkpoint
381
remat = checkpoint
382
```
383
384
Usage example:
385
```python
386
import jax
import jax.numpy as jnp


@jax.checkpoint
def expensive_layer(x, params):
    # Expensive computation that will be recomputed during backprop
    return jnp.tanh(x @ params)


def loss(activations):
    return jnp.sum(activations ** 2)


x = jnp.ones((4, 3))
params = jnp.ones((3, 2))

# Use in gradient computation to save memory
grad_fn = jax.grad(lambda params: loss(expensive_layer(x, params)))
param_grads = grad_fn(params)  # same shape as params: (3, 2)
393
```
394
395
### Custom Derivatives
396
397
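Define custom derivative rules for functions, replacing the automatically derived forward-mode (JVP) or reverse-mode (VJP) rules — for example, to improve numerical stability or efficiency.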
398
399
```python { .api }
400
def custom_gradient(fun: Callable) -> Callable:
    """
    Decorator to define a custom gradient (VJP rule) for a function.

    This is a convenience wrapper around ``custom_vjp`` in the style of
    TensorFlow's ``tf.custom_gradient``. The decorated function must return
    ``(primal_out, vjp_fn)``. ``vjp_fn(cotangent_out)`` maps the cotangent of
    the output to a tuple of input cotangents, one per positional argument
    of ``fun``.

    Args:
        fun: Function returning ``(primal_out, vjp_fn)``.

    Returns:
        Function that returns ``primal_out`` and uses ``vjp_fn`` when
        differentiated in reverse mode.
    """
413
414
def custom_jvp(fun: Callable, nondiff_argnums: Sequence[int] = ()) -> Callable:
    """
    Decorator to define a custom JVP (forward-mode derivative) rule.

    Register the rule on the returned object with ``@f.defjvp``. The rule
    has the signature ``jvp(primals, tangents) -> (primal_out, tangent_out)``.
    JAX also derives reverse-mode derivatives from this rule.

    Args:
        fun: Function to define a custom JVP for.
        nondiff_argnums: Indices of positional arguments that are not
            differentiated (e.g. functions or other non-array values).

    Returns:
        Function with custom JVP behavior.
    """
424
425
def custom_vjp(fun: Callable, nondiff_argnums: Sequence[int] = ()) -> Callable:
    """
    Decorator to define a custom VJP (reverse-mode derivative) rule.

    Register the rule on the returned object with ``f.defvjp(fwd, bwd)``:

    - ``fwd`` takes the same arguments as ``fun`` and returns
      ``(primal_out, residuals)``.
    - ``bwd(residuals, cotangent_out)`` returns a tuple of input cotangents,
      one per (differentiable) positional argument.

    Args:
        fun: Function to define a custom VJP for.
        nondiff_argnums: Indices of positional arguments that are not
            differentiated (e.g. functions or other non-array values).

    Returns:
        Function with custom VJP behavior.
    """
435
```
436
437
### Advanced Differentiation
438
439
Additional differentiation utilities, plus helpers for calling host Python code, ordering side effects, and naming computations for debugging and profiling.
440
441
```python { .api }
442
def stop_gradient(x) -> Array:
    """
    Cut gradient flow at ``x``.

    The value passes through unchanged, but automatic differentiation treats
    it as a constant from this point on.

    Args:
        x: The array whose gradient is stopped.

    Returns:
        The same value, with gradient flow blocked.
    """
452
453
def fwd_and_bwd(fun: Callable, *primals, **kwargs) -> tuple:
    """
    Run the forward and backward passes of ``fun`` as separate steps.

    Args:
        fun: The function whose forward/backward passes are computed.
        primals: The input values.
        kwargs: Additional keyword arguments.

    Returns:
        A pair ``(primal_out, vjp_fun)``.
    """
468
469
def closure_convert(
    fun: Callable,
    *example_args,
) -> tuple:
    """
    Hoist values closed over by ``fun`` into explicit arguments.

    This is intended for use with custom derivatives (``custom_vjp`` /
    ``custom_jvp``). Their rules cannot account for values captured in a
    Python closure, such as traced arrays, so those values must become
    explicit arguments.

    Args:
        fun: Function that may close over array values.
        example_args: Example arguments to ``fun`` (arrays, scalars, or pytrees
            of them) used to trace it. These are NOT the closed-over values.

    Returns:
        Pair ``(converted_fun, consts)``. ``consts`` is a list of the values
        hoisted out of ``fun``'s closure, and ``converted_fun(*args, *consts)``
        computes the same result as ``fun(*args)``.
    """
483
484
def pure_callback(
    callback: Callable,
    result_shape_dtypes,
    *args,
    sharding=None,
    vmap_method=None,
    **kwargs,
) -> Any:
    """
    Call a pure Python host function from inside a JAX computation.

    The callback must be free of side effects: JAX may skip it when its
    result is unused, or call it more than once under transformations. For
    effectful host code use ``jax.experimental.io_callback`` or
    ``jax.debug.callback`` instead.

    Args:
        callback: Pure host function, called as ``callback(*args, **kwargs)``.
        result_shape_dtypes: Pytree of objects with ``shape`` and ``dtype``
            attributes (e.g. ``jax.ShapeDtypeStruct``) describing the
            callback's return value.
        args: Positional arguments passed to ``callback``.
        sharding: Sharding specification for the result.
        vmap_method: How the callback is batched under ``vmap``, e.g.
            ``"sequential"`` (one call per batch element), ``"expand_dims"``
            or ``"broadcast_all"``.
        kwargs: Keyword arguments passed to ``callback``.

    Returns:
        The callback's result, with the structure, shapes and dtypes given by
        ``result_shape_dtypes``.
    """
506
507
def effects_barrier() -> None:
    """
    Block until side effects from earlier computations have finished.

    Any effectful computation dispatched before this call completes before
    execution moves on to later computations.
    """
514
515
def named_call(f: Callable, *, name: str) -> Callable:
516
"""
517
Wrap function with a name for debugging and profiling.
518
519
Args:
520
f: Function to wrap
521
name: Name to associate with function calls
522
523
Returns:
524
Wrapped function that appears with given name in traces
525
"""
526
527
def named_scope(name: str):
    """
    Context manager that groups enclosed JAX operations under a name.

    Args:
        name: Label applied to the computations inside the ``with`` block.

    Example:
        with jax.named_scope("layer1"):
            output = layer_computation(input)
    """
538
```
539
540
## Transformation Composition
541
542
JAX transformations can be arbitrarily composed for powerful effects:
543
544
```python
545
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    predictions = params[0] * x + params[1]
    return jnp.mean((predictions - y) ** 2)


params = jnp.array([1.0, 0.0])
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# JIT-compiled gradient
fast_grad = jax.jit(jax.grad(loss_fn))

# Vectorized gradient (per-example gradients): params shared, x/y batched
batch_grad = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

# Parallel gradient computation (inputs need a leading device axis)
parallel_grad = jax.pmap(jax.grad(loss_fn))


# Second derivatives (Hessian-vector product): forward-mode JVP of the
# reverse-mode gradient. x and y are closed over so only params varies.
def hvp(v):
    return jax.jvp(lambda p: jax.grad(loss_fn)(p, x, y), (params,), (v,))[1]


# Gradient of gradient (for meta-learning): differentiate the loss after one
# inner SGD step with respect to the step size, through the inner jax.grad.
def update_fn(meta_params):
    return params - meta_params * jax.grad(loss_fn)(params, x, y)


meta_grad = jax.grad(lambda meta_params: loss_fn(update_fn(meta_params), x, y))
559
```