Differentiate, compile, and transform Numpy code.
—
JAX's core strength lies in its composable function transformations that enable automatic differentiation, just-in-time compilation, vectorization, and parallelization. These transformations can be arbitrarily composed and applied to pure Python functions.
Compiles functions to optimized XLA code for improved performance on CPUs, GPUs, and TPUs. JIT compilation happens lazily on first call and caches compiled functions.
def jit(
fun: Callable,
in_shardings=None,
out_shardings=None,
static_argnums=None,
static_argnames=None,
donate_argnums=None,
donate_argnames=None,
keep_unused=False,
device=None,
backend=None,
inline=False,
abstracted_axes=None
) -> Callable:
"""
Just-in-time compile a function for improved performance.
Args:
fun: Function to JIT compile
in_shardings: How inputs should be sharded across devices
out_shardings: How outputs should be sharded across devices
static_argnums: Tuple of argument indices to treat as static
static_argnames: Tuple of keyword argument names to treat as static
donate_argnums: Tuple of argument indices to donate (reuse memory)
donate_argnames: Tuple of keyword argument names to donate
keep_unused: Whether to keep unused arguments in compiled function
device: Device to place computation on
backend: Backend to use for compilation
inline: Whether to inline the function
abstracted_axes: Axes to abstract for shape polymorphism
Returns:
JIT-compiled function with same signature as input
"""Usage example:
@jax.jit
def fast_computation(x, y):
return jnp.sum(x ** 2 + y ** 2)
# Or with static arguments
@jax.jit(static_argnums=(1,))
def dynamic_slice(x, size):
return x[:size]

Compute gradients of scalar-valued functions using reverse-mode automatic differentiation (backpropagation).
def grad(
fun: Callable,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[int] = ()
) -> Callable:
"""
Create function that computes gradient of scalar-valued function.
Args:
fun: Function to differentiate (must return scalar)
argnums: Argument number(s) to differentiate with respect to
has_aux: Whether function returns auxiliary data (value, aux)
holomorphic: Whether function is holomorphic (complex differentiable)
allow_int: Whether to allow integer inputs
reduce_axes: Axes to reduce over when function output is not scalar
Returns:
Function that computes gradient with respect to specified arguments
"""
def value_and_grad(
fun: Callable,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[int] = ()
) -> Callable:
"""
Create function that computes both value and gradient.
Args:
fun: Function to differentiate
argnums: Argument number(s) to differentiate with respect to
has_aux: Whether function returns auxiliary data
holomorphic: Whether function is holomorphic
allow_int: Whether to allow integer inputs
reduce_axes: Axes to reduce over when function output is not scalar
Returns:
Function that returns (value, gradient) tuple
"""Usage examples:
def loss_fn(params, x, y):
predictions = params[0] * x + params[1]
return jnp.mean((predictions - y) ** 2)
# Gradient function
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x, y)
# Value and gradient together
val_grad_fn = jax.value_and_grad(loss_fn)
loss_val, grads = val_grad_fn(params, x, y)
# Gradient with respect to multiple arguments
multi_grad_fn = jax.grad(loss_fn, argnums=(0, 1, 2))
param_grads, x_grads, y_grads = multi_grad_fn(params, x, y)

Compute full Jacobian matrices using forward-mode or reverse-mode differentiation.
def jacobian(
fun: Callable,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False
) -> Callable:
"""
Create function that computes Jacobian matrix.
Args:
fun: Function to compute Jacobian of
argnums: Argument number(s) to differentiate with respect to
has_aux: Whether function returns auxiliary data
holomorphic: Whether function is holomorphic
allow_int: Whether to allow integer inputs
Returns:
Function that returns Jacobian matrix
"""
def jacfwd(
fun: Callable,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False
) -> Callable:
"""
Jacobian using forward-mode AD (efficient for tall Jacobians).
Args:
fun: Function to differentiate
argnums: Argument number(s) to differentiate with respect to
has_aux: Whether function returns auxiliary data
holomorphic: Whether function is holomorphic
Returns:
Function that computes Jacobian using forward-mode AD
"""
def jacrev(
fun: Callable,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False
) -> Callable:
"""
Jacobian using reverse-mode AD (efficient for wide Jacobians).
Args:
fun: Function to differentiate
argnums: Argument number(s) to differentiate with respect to
has_aux: Whether function returns auxiliary data
holomorphic: Whether function is holomorphic
Returns:
Function that computes Jacobian using reverse-mode AD
"""
def hessian(
fun: Callable,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False
) -> Callable:
"""
Create function that computes Hessian matrix (second derivatives).
Args:
fun: Scalar-valued function to compute Hessian of
argnums: Argument number(s) to differentiate with respect to
has_aux: Whether function returns auxiliary data
holomorphic: Whether function is holomorphic
Returns:
Function that returns Hessian matrix
"""Lower-level differentiation primitives for building custom transformations.
def jvp(
fun: Callable,
primals: Sequence,
tangents: Sequence
) -> tuple:
"""
Jacobian-vector product using forward-mode AD.
Args:
fun: Function to differentiate
primals: Point at which to evaluate function
tangents: Tangent vectors to multiply Jacobian by
Returns:
Tuple of (primals_out, tangents_out)
"""
def vjp(
fun: Callable,
*primals
) -> tuple:
"""
Vector-Jacobian product using reverse-mode AD.
Args:
fun: Function to differentiate
primals: Point at which to evaluate function
Returns:
Tuple of (primals_out, vjp_fun) where vjp_fun computes VJP
"""
def linearize(fun: Callable, *primals) -> tuple:
"""
Linearize function around given point.
Args:
fun: Function to linearize
primals: Point to linearize around
Returns:
Tuple of (primals_out, jvp_fun) for computing JVPs
"""Transform functions to work on batches of inputs by adding a batch dimension and vectorizing over it.
def vmap(
fun: Callable,
in_axes=0,
out_axes=0,
axis_name=None,
axis_size=None,
spmd_axis_name=None
) -> Callable:
"""
Vectorizing map that adds batch dimension to function.
Args:
fun: Function to vectorize
in_axes: How to map over input arguments (int, None, or tuple)
out_axes: How to map over output values (int, None, or tuple)
axis_name: Name for the mapped axis (for use with psum etc.)
axis_size: Size of mapped axis (for use with axis_name)
spmd_axis_name: SPMD axis name for multi-device computation
Returns:
Vectorized function that works on batches
"""Usage examples:
# Vectorize over first axis of both inputs
batch_fn = jax.vmap(single_example_fn)
batch_outputs = batch_fn(batch_inputs)
# Vectorize with different input axes
# x has batch dim 0, y has batch dim 1
fn = jax.vmap(process_fn, in_axes=(0, 1))
# Vectorize with no batch dim for some inputs
# x has batch dim 0, y is broadcast to all batch elements
fn = jax.vmap(process_fn, in_axes=(0, None))

Distribute computation across multiple devices using SPMD (Single Program, Multiple Data) parallelism.
def pmap(
fun: Callable,
axis_name=None,
in_axes=0,
out_axes=0,
static_broadcasted_argnums=(),
devices=None,
backend=None,
axis_size=None,
donate_argnums=(),
global_arg_shapes=None
) -> Callable:
"""
Parallel map that distributes computation across multiple devices.
Args:
fun: Function to parallelize
axis_name: Name for the parallel axis
in_axes: How to split inputs across devices
out_axes: How to collect outputs from devices
static_broadcasted_argnums: Arguments to broadcast to all devices
devices: Explicit device placement
backend: Backend to use
axis_size: Size of parallel axis
donate_argnums: Arguments to donate memory
global_arg_shapes: Global shapes for arguments
Returns:
Function that runs in parallel across devices
"""Usage example:
# Function runs on each device with its slice of data
parallel_fn = jax.pmap(single_device_fn)
# Input shape: (num_devices, per_device_batch_size, ...)
outputs = parallel_fn(distributed_inputs)

Trade computation for memory using gradient checkpointing (rematerialization).
def checkpoint(
fun: Callable,
*,
concrete: bool = False,
policy: Callable = None,
prevent_cse: bool = True,
static_argnums: int | Sequence[int] = ()
) -> Callable:
"""
Gradient checkpointing for memory-efficient backpropagation.
Args:
fun: Function to apply checkpointing to
concrete: Whether to use concrete checkpointing
policy: Policy for deciding what to checkpoint
prevent_cse: Whether to prevent common subexpression elimination
static_argnums: Arguments to treat as static
Returns:
Checkpointed function that saves memory during backward pass
"""
# Alias for checkpoint
remat = checkpoint

Usage example:
@jax.checkpoint
def expensive_layer(x, params):
# Expensive computation that will be recomputed during backprop
return jnp.tanh(x @ params)
# Use in gradient computation to save memory
grad_fn = jax.grad(lambda params: loss(expensive_layer(x, params)))

Define custom forward and backward passes for functions.
def custom_gradient(fun: Callable) -> Callable:
"""
Decorator to define custom gradient for function.
The decorated function should return (primal_out, grad_fn) where
grad_fn(cotangents) -> tangents.
Args:
fun: Function with custom gradient implementation
Returns:
Function with custom gradient behavior
"""
def custom_jvp(fun: Callable) -> Callable:
"""
Decorator to define custom JVP (forward-mode derivative) rule.
Args:
fun: Function to define custom JVP for
Returns:
Function with custom JVP behavior
"""
def custom_vjp(fun: Callable) -> Callable:
"""
Decorator to define custom VJP (reverse-mode derivative) rule.
Args:
fun: Function to define custom VJP for
Returns:
Function with custom VJP behavior
"""Additional differentiation utilities and transformations.
def stop_gradient(x) -> Array:
"""
Stop gradient computation at this point.
Args:
x: Array to stop gradient for
Returns:
Array with gradient flow stopped
"""
def fwd_and_bwd(
fun: Callable,
*primals,
**kwargs
) -> tuple:
"""
Compute forward and backward passes separately.
Args:
fun: Function to compute forward/backward for
primals: Input values
Returns:
Tuple of (primal_out, vjp_fun)
"""
def closure_convert(
fun: Callable,
*closed_over_vals
) -> tuple:
"""
Convert function with closure variables for differentiation.
Args:
fun: Function with closure variables
closed_over_vals: Values closed over by function
Returns:
Converted function and closure values
"""
def pure_callback(
callback: Callable,
result_shape_dtypes,
*args,
sharding=None,
vmap_method=None,
**kwargs
) -> Any:
"""
Call host function with pure side effects from JAX computation.
Args:
callback: Pure host function to call
result_shape_dtypes: Shape and dtype of callback result
args: Arguments to pass to callback
sharding: Sharding specification for result
vmap_method: How to handle vectorization
kwargs: Additional keyword arguments
Returns:
Result of callback with specified shape and dtype
"""
def effects_barrier() -> None:
"""
Create synchronization barrier for side effects.
Ensures all preceding computations with side effects complete
before continuing with subsequent computations.
"""
def named_call(f: Callable, *, name: str) -> Callable:
"""
Wrap function with a name for debugging and profiling.
Args:
f: Function to wrap
name: Name to associate with function calls
Returns:
Wrapped function that appears with given name in traces
"""
def named_scope(name: str):
"""
Context manager for named scopes in JAX computations.
Args:
name: Name for the computation scope
Usage:
with jax.named_scope("layer1"):
output = layer_computation(input)
"""JAX transformations can be arbitrarily composed for powerful effects:
# JIT-compiled gradient
fast_grad = jax.jit(jax.grad(loss_fn))
# Vectorized gradient (per-example gradients)
batch_grad = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))
# Parallel gradient computation
parallel_grad = jax.pmap(jax.grad(loss_fn))
# Second derivatives (Hessian-vector product)
hvp = lambda v: jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]
# Gradient of gradient (for meta-learning)
meta_grad = jax.grad(lambda meta_params: loss_fn(update_fn(meta_params)))

Install with Tessl CLI
npx tessl i tessl/pypi-jax