tessl/pypi-jax

Differentiate, compile, and transform NumPy code.


Core Program Transformations

JAX's core strength lies in its composable function transformations that enable automatic differentiation, just-in-time compilation, vectorization, and parallelization. These transformations can be arbitrarily composed and applied to pure Python functions.

Capabilities

Just-in-Time Compilation

Compiles functions to optimized XLA code for improved performance on CPUs, GPUs, and TPUs. Compilation is triggered lazily on the first call, and the compiled executable is cached and reused for later calls with the same input shapes, dtypes, and static argument values.

def jit(
    fun: Callable,
    in_shardings=None,
    out_shardings=None,
    static_argnums=None,
    static_argnames=None,
    donate_argnums=None,
    donate_argnames=None,
    keep_unused=False,
    device=None,
    backend=None,
    inline=False,
    abstracted_axes=None
) -> Callable:
    """
    Just-in-time compile a function for improved performance.
    
    Args:
        fun: Function to JIT compile
        in_shardings: How inputs should be sharded across devices
        out_shardings: How outputs should be sharded across devices
        static_argnums: Tuple of argument indices to treat as static
        static_argnames: Tuple of keyword argument names to treat as static
        donate_argnums: Tuple of argument indices to donate (reuse memory)
        donate_argnames: Tuple of keyword argument names to donate
        keep_unused: Whether to keep unused arguments in compiled function
        device: Device to place computation on
        backend: Backend to use for compilation
        inline: Whether to inline the function
        abstracted_axes: Axes to abstract for shape polymorphism
        
    Returns:
        JIT-compiled function with same signature as input
    """

Usage example:

@jax.jit
def fast_computation(x, y):
    return jnp.sum(x ** 2 + y ** 2)

# Or with static arguments (use functools.partial, since jax.jit
# takes the function as its first argument)
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def dynamic_slice(x, size):
    return x[:size]

Automatic Differentiation

Compute gradients of scalar-valued functions using reverse-mode automatic differentiation (backpropagation).

def grad(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = ()
) -> Callable:
    """
    Create function that computes gradient of scalar-valued function.
    
    Args:
        fun: Function to differentiate (must return scalar)
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether function returns auxiliary data (value, aux)
        holomorphic: Whether function is holomorphic (complex differentiable)
        allow_int: Whether to allow integer inputs
        reduce_axes: Named axes to sum the gradient over (e.g. axes bound by pmap or vmap with axis_name)
        
    Returns:
        Function that computes gradient with respect to specified arguments
    """

def value_and_grad(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = ()
) -> Callable:
    """
    Create function that computes both value and gradient.
    
    Args:
        fun: Function to differentiate
        argnums: Argument number(s) to differentiate with respect to  
        has_aux: Whether function returns auxiliary data
        holomorphic: Whether function is holomorphic
        allow_int: Whether to allow integer inputs
        reduce_axes: Named axes to sum the gradient over (e.g. axes bound by pmap or vmap with axis_name)
        
    Returns:
        Function that returns (value, gradient) tuple
    """

Usage examples:

def loss_fn(params, x, y):
    predictions = params[0] * x + params[1]
    return jnp.mean((predictions - y) ** 2)

# Gradient function
grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x, y)

# Value and gradient together
val_grad_fn = jax.value_and_grad(loss_fn)
loss_val, grads = val_grad_fn(params, x, y)

# Gradient with respect to multiple arguments
multi_grad_fn = jax.grad(loss_fn, argnums=(0, 1, 2))
param_grads, x_grads, y_grads = multi_grad_fn(params, x, y)

Jacobian Computation

Compute full Jacobian matrices using forward-mode or reverse-mode differentiation.

def jacobian(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False
) -> Callable:
    """
    Create function that computes Jacobian matrix.
    
    Args:
        fun: Function to compute Jacobian of
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether function returns auxiliary data
        holomorphic: Whether function is holomorphic
        allow_int: Whether to allow integer inputs
        
    Returns:
        Function that returns Jacobian matrix
    """

def jacfwd(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False
) -> Callable:
    """
    Jacobian using forward-mode AD (efficient for tall Jacobians).
    
    Args:
        fun: Function to differentiate
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether function returns auxiliary data
        holomorphic: Whether function is holomorphic
        
    Returns:
        Function that computes Jacobian using forward-mode AD
    """

def jacrev(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False
) -> Callable:
    """
    Jacobian using reverse-mode AD (efficient for wide Jacobians).
    
    Args:
        fun: Function to differentiate  
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether function returns auxiliary data
        holomorphic: Whether function is holomorphic
        
    Returns:
        Function that computes Jacobian using reverse-mode AD
    """

def hessian(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False
) -> Callable:
    """
    Create function that computes Hessian matrix (second derivatives).
    
    Args:
        fun: Scalar-valued function to compute Hessian of
        argnums: Argument number(s) to differentiate with respect to
        has_aux: Whether function returns auxiliary data
        holomorphic: Whether function is holomorphic
        
    Returns:
        Function that returns Hessian matrix
    """
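A brief sketch of how these are typically used, with a made-up vector-valued function `f` and scalar function `g`:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Vector-valued function R^3 -> R^3
    return jnp.array([x[0] ** 2, x[1] * x[2], jnp.sin(x[0])])

def g(x):
    # Scalar-valued function for the Hessian
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(f)(x)  # forward-mode Jacobian, shape (3, 3)
J_rev = jax.jacrev(f)(x)  # reverse-mode Jacobian, same values
H = jax.hessian(g)(x)     # Hessian of g; for sum(x**2) it is 2 * identity
```

Both modes produce the same Jacobian; the choice only affects cost, as noted above.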

Forward and Reverse Mode Primitives

Lower-level differentiation primitives for building custom transformations.

def jvp(
    fun: Callable,
    primals: Sequence,
    tangents: Sequence
) -> tuple:
    """
    Jacobian-vector product using forward-mode AD.
    
    Args:
        fun: Function to differentiate
        primals: Point at which to evaluate function
        tangents: Tangent vectors to multiply Jacobian by
        
    Returns:
        Tuple of (primals_out, tangents_out)
    """

def vjp(
    fun: Callable,
    *primals
) -> tuple:
    """
    Vector-Jacobian product using reverse-mode AD.
    
    Args:
        fun: Function to differentiate
        primals: Point at which to evaluate function
        
    Returns:
        Tuple of (primals_out, vjp_fun) where vjp_fun computes VJP
    """

def linearize(fun: Callable, *primals) -> tuple:
    """
    Linearize function around given point.
    
    Args:
        fun: Function to linearize
        primals: Point to linearize around
        
    Returns:
        Tuple of (primals_out, jvp_fun) for computing JVPs
    """
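A minimal sketch of these primitives on a toy elementwise function (names here are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0])
v = jnp.array([1.0, 0.0])

# Forward mode: push the tangent vector v through f at x
y, tangent_out = jax.jvp(f, (x,), (v,))

# Reverse mode: vjp returns the output plus a function that pulls
# output cotangents back to input cotangents
y2, vjp_fn = jax.vjp(f, x)
(cotangent,) = vjp_fn(jnp.ones_like(y2))

# linearize: like jvp, but the returned jvp_fn can be reused for many
# tangent vectors without re-evaluating f
y3, jvp_fn = jax.linearize(f, x)
tangent_again = jvp_fn(v)
```

Pulling back a cotangent of all ones through `vjp_fn` is equivalent to taking the gradient of `sum(f(x))`.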

Vectorization

Transform functions to work on batches of inputs by adding a batch dimension and vectorizing over it.

def vmap(
    fun: Callable,
    in_axes=0,
    out_axes=0,
    axis_name=None,
    axis_size=None,
    spmd_axis_name=None
) -> Callable:
    """
    Vectorizing map that adds batch dimension to function.
    
    Args:
        fun: Function to vectorize
        in_axes: How to map over input arguments (int, None, or tuple)
        out_axes: How to map over output values (int, None, or tuple)
        axis_name: Name for the mapped axis (for use with psum etc.)
        axis_size: Size of mapped axis (for use with axis_name)
        spmd_axis_name: SPMD axis name for multi-device computation
        
    Returns:
        Vectorized function that works on batches
    """

Usage examples:

# Vectorize over first axis of both inputs
batch_fn = jax.vmap(single_example_fn)
batch_outputs = batch_fn(batch_inputs)

# Vectorize with different input axes
# x has batch dim 0, y has batch dim 1
fn = jax.vmap(process_fn, in_axes=(0, 1))

# Vectorize with no batch dim for some inputs
# x has batch dim 0, y is broadcast to all batch elements
fn = jax.vmap(process_fn, in_axes=(0, None))

Parallelization

Distribute computation across multiple devices using SPMD (Single Program, Multiple Data) parallelism.

def pmap(
    fun: Callable,
    axis_name=None,
    in_axes=0,
    out_axes=0,
    static_broadcasted_argnums=(),
    devices=None,
    backend=None,
    axis_size=None,
    donate_argnums=(),
    global_arg_shapes=None
) -> Callable:
    """
    Parallel map that distributes computation across multiple devices.
    
    Args:
        fun: Function to parallelize
        axis_name: Name for the parallel axis
        in_axes: How to split inputs across devices
        out_axes: How to collect outputs from devices
        static_broadcasted_argnums: Arguments to broadcast to all devices
        devices: Explicit device placement
        backend: Backend to use
        axis_size: Size of parallel axis
        donate_argnums: Arguments to donate memory
        global_arg_shapes: Global shapes for arguments
        
    Returns:
        Function that runs in parallel across devices
    """

Usage example:

# Function runs on each device with its slice of data
parallel_fn = jax.pmap(single_device_fn)
# Input shape: (num_devices, per_device_batch_size, ...)
outputs = parallel_fn(distributed_inputs)

Memory-Efficient Gradient Computation

Trade computation for memory using gradient checkpointing (rematerialization).

def checkpoint(
    fun: Callable,
    *,
    concrete: bool = False,
    policy: Callable | None = None,
    prevent_cse: bool = True,
    static_argnums: int | Sequence[int] = ()
) -> Callable:
    """
    Gradient checkpointing for memory-efficient backpropagation.
    
    Args:
        fun: Function to apply checkpointing to
        concrete: Whether to use concrete checkpointing
        policy: Policy for deciding what to checkpoint
        prevent_cse: Whether to prevent common subexpression elimination
        static_argnums: Arguments to treat as static
        
    Returns:
        Checkpointed function that saves memory during backward pass
    """

# Alias for checkpoint
remat = checkpoint

Usage example:

@jax.checkpoint
def expensive_layer(x, params):
    # Expensive computation that will be recomputed during backprop
    return jnp.tanh(x @ params)

# Use in gradient computation to save memory
grad_fn = jax.grad(lambda params: loss(expensive_layer(x, params)))

Custom Derivatives

Define custom forward and backward passes for functions.

def custom_gradient(fun: Callable) -> Callable:
    """
    Decorator to define custom gradient for function.
    
    The decorated function should return (primal_out, grad_fn), where
    grad_fn maps output cotangents to input cotangents.
    
    Args:
        fun: Function with custom gradient implementation
        
    Returns:
        Function with custom gradient behavior
    """

def custom_jvp(fun: Callable) -> Callable:
    """
    Decorator to define custom JVP (forward-mode derivative) rule.
    
    Args:
        fun: Function to define custom JVP for
        
    Returns:
        Function with custom JVP behavior
    """

def custom_vjp(fun: Callable) -> Callable:
    """
    Decorator to define custom VJP (reverse-mode derivative) rule.
    
    Args:
        fun: Function to define custom VJP for
        
    Returns:
        Function with custom VJP behavior
    """
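As an illustrative sketch, here is the standard `custom_jvp` pattern applied to a square root whose derivative is clamped at zero (the name `safe_sqrt` is made up for this example):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def safe_sqrt(x):
    return jnp.sqrt(x)

@safe_sqrt.defjvp
def safe_sqrt_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = safe_sqrt(x)
    # Clamp the derivative at x == 0 instead of letting it blow up to inf
    dydx = jnp.where(x > 0, 0.5 / jnp.maximum(y, 1e-12), 0.0)
    return y, dydx * t
```

Because the JVP rule is linear in the tangent, reverse-mode `jax.grad` works on `safe_sqrt` as well; `custom_vjp` follows the analogous pattern with `defvjp` and separate forward/backward functions.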

Advanced Differentiation

Additional differentiation utilities and transformations.

def stop_gradient(x) -> Array:
    """
    Stop gradient computation at this point.
    
    Args:
        x: Array to stop gradient for
        
    Returns:
        Array with gradient flow stopped
    """
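A small sketch of the effect (note that in upstream JAX this function lives at `jax.lax.stop_gradient`):

```python
import jax

def f(x):
    # The second factor is treated as a constant under differentiation
    return x * jax.lax.stop_gradient(x)

# d/dx [x * c] = c, with c frozen at the value of x, so the gradient at
# 3.0 is 3.0 rather than the 6.0 that d/dx [x ** 2] would give
g = jax.grad(f)(3.0)
```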

def fwd_and_bwd(
    fun: Callable,
    *primals,
    **kwargs
) -> tuple:
    """
    Compute forward and backward passes separately.
    
    Args:
        fun: Function to compute forward/backward for
        primals: Input values
        
    Returns:
        Tuple of (primal_out, vjp_fun) 
    """

def closure_convert(
    fun: Callable,
    *closed_over_vals
) -> tuple:
    """
    Convert function with closure variables for differentiation.
    
    Args:
        fun: Function with closure variables
        closed_over_vals: Values closed over by function
        
    Returns:
        Converted function and closure values
    """

def pure_callback(
    callback: Callable,
    result_shape_dtypes,
    *args,
    sharding=None,
    vmap_method=None,
    **kwargs
) -> Any:
    """
    Call host function with pure side effects from JAX computation.
    
    Args:
        callback: Pure host function to call
        result_shape_dtypes: Shape and dtype of callback result
        args: Arguments to pass to callback
        sharding: Sharding specification for result
        vmap_method: How to handle vectorization
        kwargs: Additional keyword arguments
        
    Returns:
        Result of callback with specified shape and dtype
    """
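A sketch of calling out to host NumPy from inside a jitted function (the callback name is illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_log_abs(x):
    # Runs on the host with plain NumPy, outside the traced computation
    return np.log(np.abs(x)).astype(x.dtype)

@jax.jit
def f(x):
    # JAX needs the result shape/dtype up front, since the callback is
    # opaque to the tracer
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log_abs, out_spec, x)
```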

def effects_barrier() -> None:
    """
    Create synchronization barrier for side effects.
    
    Ensures all preceding computations with side effects complete
    before continuing with subsequent computations.
    """

def named_call(f: Callable, *, name: str) -> Callable:
    """
    Wrap function with a name for debugging and profiling.
    
    Args:
        f: Function to wrap
        name: Name to associate with function calls
        
    Returns:
        Wrapped function that appears with given name in traces
    """

def named_scope(name: str):
    """
    Context manager for named scopes in JAX computations.
    
    Args:
        name: Name for the computation scope
        
    Usage:
        with jax.named_scope("layer1"):
            output = layer_computation(input)
    """

Transformation Composition

JAX transformations can be arbitrarily composed for powerful effects:

# JIT-compiled gradient
fast_grad = jax.jit(jax.grad(loss_fn))

# Vectorized gradient (per-example gradients)  
batch_grad = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))

# Parallel gradient computation
parallel_grad = jax.pmap(jax.grad(loss_fn))

# Second derivatives (Hessian-vector product); here loss_fn is assumed
# to be a function of params only
hvp = lambda v: jax.jvp(jax.grad(loss_fn), (params,), (v,))[1]

# Gradient of gradient (for meta-learning)
meta_grad = jax.grad(lambda meta_params: loss_fn(update_fn(meta_params)))

Install with Tessl CLI

npx tessl i tessl/pypi-jax
