Differentiate, compile, and transform NumPy code.
—
JAX LAX provides direct XLA operations and primitives for high-performance computing. These low-level functions offer precise control over computation and serve as building blocks for higher-level JAX operations.
import jax.lax as lax
from jax.lax import add, mul, dot_general, cond, scan

Element-wise arithmetic operations that map directly to XLA primitives.

def add(x, y) -> Array:
"""Element-wise addition."""
def sub(x, y) -> Array:
"""Element-wise subtraction."""
def mul(x, y) -> Array:
"""Element-wise multiplication."""
def div(x, y) -> Array:
"""Element-wise division."""
def rem(x, y) -> Array:
"""Element-wise remainder."""
def max(x, y) -> Array:
"""Element-wise maximum."""
def min(x, y) -> Array:
"""Element-wise minimum."""
def abs(x) -> Array:
"""Element-wise absolute value."""
def neg(x) -> Array:
"""Element-wise negation."""
def sign(x) -> Array:
"""Element-wise sign function."""
def pow(x, y) -> Array:
"""Element-wise power operation."""
def integer_pow(x, y) -> Array:
"""Element-wise integer power."""
def reciprocal(x) -> Array:
"""Element-wise reciprocal (1/x)."""
def square(x) -> Array:
"""Element-wise square."""
def sqrt(x) -> Array:
"""Element-wise square root."""
def rsqrt(x) -> Array:
"""Element-wise reciprocal square root (1/√x)."""
def cbrt(x) -> Array:
"""Element-wise cube root."""
def clamp(min, x, max) -> Array:
"""
Clamp values between minimum and maximum.
Args:
min: Minimum value
x: Input array
max: Maximum value
Returns:
Array with values clamped to [min, max]
"""

Transcendental and special mathematical functions.

# Trigonometric functions
def sin(x) -> Array: ...
def cos(x) -> Array: ...
def tan(x) -> Array: ...
def asin(x) -> Array: ...
def acos(x) -> Array: ...
def atan(x) -> Array: ...
def atan2(x, y) -> Array: ...
# Hyperbolic functions
def sinh(x) -> Array: ...
def cosh(x) -> Array: ...
def tanh(x) -> Array: ...
def asinh(x) -> Array: ...
def acosh(x) -> Array: ...
def atanh(x) -> Array: ...
# Exponential and logarithmic
def exp(x) -> Array: ...
def exp2(x) -> Array: ...
def expm1(x) -> Array: ...
def log(x) -> Array: ...
def log1p(x) -> Array: ...
def logistic(x) -> Array: ...
# Rounding operations
def ceil(x) -> Array: ...
def floor(x) -> Array: ...
def round(x) -> Array: ...
# Complex number operations
def complex(real, imag) -> Array:
"""Create complex array from real and imaginary parts."""
def conj(x) -> Array:
"""Complex conjugate."""
def real(x) -> Array:
"""Extract real part of complex array."""
def imag(x) -> Array:
"""Extract imaginary part of complex array."""

Element-wise comparison operations returning boolean arrays.

def eq(x, y) -> Array:
"""Element-wise equality."""
def ne(x, y) -> Array:
"""Element-wise inequality."""
def lt(x, y) -> Array:
"""Element-wise less than."""
def le(x, y) -> Array:
"""Element-wise less than or equal."""
def gt(x, y) -> Array:
"""Element-wise greater than."""
def ge(x, y) -> Array:
"""Element-wise greater than or equal."""
def is_finite(x) -> Array:
"""Element-wise finite number test."""

Bitwise operations on integer arrays.

# Bitwise operations
def bitwise_and(x, y) -> Array: ...
def bitwise_or(x, y) -> Array: ...
def bitwise_xor(x, y) -> Array: ...
def bitwise_not(x) -> Array: ...
# Bit shifting
def shift_left(x, y) -> Array: ...
def shift_right_logical(x, y) -> Array: ...
def shift_right_arithmetic(x, y) -> Array: ...
# Bit manipulation
def clz(x) -> Array:
"""Count leading zeros."""
def population_count(x) -> Array:
"""Count set bits."""

Shape manipulation, broadcasting, and array transformation operations.

def broadcast(operand, sizes) -> Array:
"""Broadcast array by adding dimensions."""
def broadcast_in_dim(operand, shape, broadcast_dimensions) -> Array:
"""Broadcast array into target shape."""
def reshape(operand, new_sizes, dimensions=None) -> Array:
"""Reshape array to new dimensions."""
def transpose(operand, permutation) -> Array:
"""Transpose array axes."""
def rev(operand, dimensions) -> Array:
"""Reverse array along specified dimensions."""
def concatenate(operands, dimension) -> Array:
"""Concatenate arrays along dimension."""
def pad(operand, padding_value, padding_config) -> Array:
"""Pad array with constant value."""
def squeeze(array, dimensions) -> Array:
"""Remove unit dimensions."""
def expand_dims(array, dimensions) -> Array:
"""Add unit dimensions."""

Advanced indexing operations for array access and updates.

def slice(operand, start_indices, limit_indices, strides=None) -> Array:
"""Extract slice from array."""
def slice_in_dim(operand, start, limit, stride=1, axis=0) -> Array:
"""Slice array along single dimension."""
def dynamic_slice(operand, start_indices, slice_sizes) -> Array:
"""Extract slice with dynamic start indices."""
def dynamic_slice_in_dim(operand, start, size, axis=0) -> Array:
"""Dynamic slice along single dimension."""
def dynamic_update_slice(operand, update, start_indices) -> Array:
"""Update slice with dynamic start indices."""
def dynamic_update_slice_in_dim(operand, update, start, axis) -> Array:
"""Dynamic update slice along single dimension."""
def gather(
operand,
start_indices,
dimension_numbers,
slice_sizes,
indices_are_sorted=False,
unique_indices=False,
mode=None,
fill_value=None
) -> Array:
"""General gather operation for advanced indexing."""
def scatter(
operand,
scatter_indices,
updates,
dimension_numbers,
indices_are_sorted=False,
unique_indices=False,
mode=None
) -> Array:
"""General scatter operation for advanced updates."""
# Scatter variants for different operations
def scatter_add(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_sub(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_mul(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_max(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_min(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def index_in_dim(operand, index, axis=0, keepdims=True) -> Array:
"""Index array along single dimension."""
def index_take(src, idxs, axes) -> Array:
"""Take elements using multi-dimensional indices."""

Reduce arrays along specified axes using various operations.

def reduce(
operand,
init_value,
computation,
dimensions
) -> Array:
"""
General reduction operation.
Args:
operand: Array to reduce
init_value: Initial value for reduction
computation: Binary function for reduction
dimensions: Axes to reduce over
Returns:
Reduced array
"""
# Specialized reductions
def reduce_sum(operand, axes) -> Array: ...
def reduce_prod(operand, axes) -> Array: ...
def reduce_max(operand, axes) -> Array: ...
def reduce_min(operand, axes) -> Array: ...
def reduce_and(operand, axes) -> Array: ...
def reduce_or(operand, axes) -> Array: ...
def reduce_xor(operand, axes) -> Array: ...
# Windowed reductions
def reduce_window(
operand,
init_value,
computation,
window_dimensions,
window_strides=None,
padding=None,
base_dilation=None,
window_dilation=None
) -> Array:
"""
Sliding window reduction.
Args:
operand: Input array
init_value: Initial value for reduction
computation: Binary reduction function
window_dimensions: Size of sliding window
window_strides: Stride of sliding window
padding: Padding specification
base_dilation: Base dilation factor
window_dilation: Window dilation factor
Returns:
Reduced array with window operation applied
"""

Conditional execution and loop constructs for dynamic computation graphs.

def cond(pred, true_fun, false_fun, *operands) -> Any:
"""
Conditional execution based on predicate.
Args:
pred: Boolean scalar predicate
true_fun: Function to execute if pred is True
false_fun: Function to execute if pred is False
operands: Arguments to pass to selected function
Returns:
Result of executing selected function
"""
def select(pred, on_true, on_false) -> Array:
"""Element-wise conditional selection."""
def select_n(which, *cases) -> Array:
"""Multi-way conditional selection."""
def while_loop(cond_fun, body_fun, init_val) -> Any:
"""
While loop with condition and body functions.
Args:
cond_fun: Function that returns boolean condition
body_fun: Function that updates loop state
init_val: Initial loop state
Returns:
Final loop state after termination
"""
def fori_loop(lower, upper, body_fun, init_val) -> Any:
"""
For loop over range with body function.
Args:
lower: Loop start index
upper: Loop end index (exclusive)
body_fun: Function that updates state (takes index and state)
init_val: Initial loop state
Returns:
Final loop state
"""
def scan(f, init, xs, length=None, reverse=False, unroll=1) -> tuple[Any, Array]:
"""
Scan operation applying function over sequence.
Args:
f: Function to apply (takes carry and input, returns new carry and output)
init: Initial carry value
xs: Input sequence
length: Length of sequence (inferred if None)
reverse: Whether to scan in reverse
unroll: Number of iterations to unroll
Returns:
Tuple of (final_carry, outputs)
"""
def associative_scan(fn, elems, reverse=False, axis=0) -> Array:
"""
Parallel associative scan operation.
Args:
fn: Associative binary function
elems: Input sequence
reverse: Whether to scan in reverse
axis: Axis to scan along
Returns:
Scanned results
"""
def switch(index, branches, *operands) -> Any:
"""
Switch statement for multi-way branching.
Args:
index: Integer index selecting branch
branches: List of functions (branches)
operands: Arguments to pass to selected branch
Returns:
Result of executing selected branch
"""
def map(f, xs) -> Array:
"""Map function over leading axis of array."""

Cumulative operations along array axes.

def cumsum(operand, axis=0, reverse=False) -> Array:
"""Cumulative sum along axis."""
def cumprod(operand, axis=0, reverse=False) -> Array:
"""Cumulative product along axis."""
def cummax(operand, axis=0, reverse=False) -> Array:
"""Cumulative maximum along axis."""
def cummin(operand, axis=0, reverse=False) -> Array:
"""Cumulative minimum along axis."""
def cumlogsumexp(operand, axis=0, reverse=False) -> Array:
"""Cumulative log-sum-exp along axis."""

Matrix operations and linear algebra primitives.

def dot(lhs, rhs, precision=None, preferred_element_type=None) -> Array:
"""Matrix multiplication for 1D and 2D arrays."""
def dot_general(
lhs,
rhs,
dimension_numbers,
precision=None,
preferred_element_type=None
) -> Array:
"""
General matrix multiplication with custom contractions.
Args:
lhs: Left-hand side array
rhs: Right-hand side array
dimension_numbers: Specification of contraction and batch dimensions
precision: Computation precision
preferred_element_type: Preferred output element type
Returns:
Result of general matrix multiplication
"""
def batch_matmul(
lhs,
rhs,
precision=None,
preferred_element_type=None
) -> Array:
"""Batched matrix multiplication."""
class DotDimensionNumbers:
"""Dimension specification for dot_general operation."""
lhs_contracting_dimensions: tuple[int, ...]
rhs_contracting_dimensions: tuple[int, ...]
lhs_batch_dimensions: tuple[int, ...]
rhs_batch_dimensions: tuple[int, ...]

Advanced linear algebra operations from jax.lax.linalg.

def cholesky(a, *, symmetrize_input: bool = True) -> Array:
"""
Cholesky decomposition of positive definite matrix.
Args:
a: Positive definite matrix
symmetrize_input: Whether to symmetrize input
Returns:
Lower triangular Cholesky factor
"""
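def cholesky_update(r_matrix, w_vector) -> Array:
"""
Rank-1 update to Cholesky factorization.
Given an upper-triangular factor R with A = R.T @ R and a vector w, computes the Cholesky factor of A + w @ w.T.
Args:
r_matrix: Upper-triangular Cholesky factor
w_vector: Update vector
Returns:
Updated Cholesky factor
"""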
def eig(a, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True) -> tuple[Array, Array, Array]:
"""
Eigenvalue decomposition of general matrix.
Args:
a: Input matrix
compute_left_eigenvectors: Whether to compute left eigenvectors
compute_right_eigenvectors: Whether to compute right eigenvectors
Returns:
Tuple of (eigenvalues, left_eigenvectors, right_eigenvectors)
"""
def eigh(a, *, lower: bool = True, symmetrize_input: bool = True, sort_eigenvalues: bool = True) -> tuple[Array, Array]:
"""
Eigenvalue decomposition of Hermitian matrix.
Args:
a: Hermitian matrix
lower: Whether to use lower triangle
symmetrize_input: Whether to symmetrize input
sort_eigenvalues: Whether to sort eigenvalues
Returns:
Tuple of (eigenvectors, eigenvalues)
"""
def lu(a) -> tuple[Array, Array, Array]:
"""
LU decomposition with partial pivoting.
Args:
a: Input matrix
Returns:
Tuple of (lu_factors, pivots, permutation)
"""
def qr(a, *, full_matrices: bool = True) -> tuple[Array, Array]:
"""
QR decomposition.
Args:
a: Input matrix
full_matrices: Whether to return full or reduced QR
Returns:
Tuple of (q, r) matrices
"""
def svd(a, *, full_matrices: bool = True, compute_uv: bool = True, subset_by_index=None, algorithm=None) -> tuple[Array, Array, Array]:
"""
Singular value decomposition.
Args:
a: Input matrix
full_matrices: Whether to return full or reduced SVD
compute_uv: Whether to compute U and V matrices
subset_by_index: Optional (start, end) index range selecting which singular values to compute
algorithm: Optional SVD algorithm selection
Returns:
Tuple of (u, s, vh) where A = U @ diag(s) @ Vh; only the singular values s when compute_uv is False
"""
def schur(a, *, compute_schur_vectors: bool = True, sort_eig_vals: bool = False, select_callable=None) -> tuple[Array, Array]:
"""
Schur decomposition.
Args:
a: Input matrix
compute_schur_vectors: Whether to compute Schur vectors
sort_eig_vals: Whether to sort eigenvalues
select_callable: Selection function for eigenvalues
Returns:
Tuple of (schur_form, schur_vectors)
"""
def hessenberg(a) -> tuple[Array, Array]:
"""
Hessenberg decomposition.
Args:
a: Input matrix
Returns:
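Tuple of (a, taus): a holds the upper Hessenberg matrix with the Householder reflectors stored below the first subdiagonal; taus holds the reflectors' scalar factors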
"""
def triangular_solve(a, b, *, left_side: bool = True, lower: bool = True, transpose_a: bool = False, conjugate_a: bool = False, unit_diagonal: bool = False) -> Array:
"""
Solve triangular system of equations.
Args:
a: Triangular matrix
b: Right-hand side
left_side: Whether A is on left side (Ax = b) or right (xA = b)
lower: Whether A is lower triangular
transpose_a: Whether to transpose A
conjugate_a: Whether to conjugate A
unit_diagonal: Whether A has unit diagonal
Returns:
Solution to triangular system
"""
def tridiagonal(a, *, lower: bool = True) -> tuple[Array, Array, Array, Array]:
"""
Tridiagonal reduction of symmetric matrix.
Args:
a: Symmetric matrix
lower: Whether to use lower triangle
Returns:
Tuple of (a, d, e, taus): the reduced matrix (tridiagonal form plus Householder reflectors), its diagonal d, its off-diagonal e, and the Householder scalar factors taus
"""
def tridiagonal_solve(dl, d, du, b) -> Array:
"""
Solve tridiagonal system using Thomas algorithm.
Args:
dl: Lower diagonal
d: Main diagonal
du: Upper diagonal
b: Right-hand side
Returns:
Solution to tridiagonal system
"""
def qdwh(a, *, is_hermitian: bool = False, max_iterations: int | None = None, dynamic_shape: tuple[int, int] | None = None) -> tuple[Array, Array, int, bool]:
"""
QDWH polar decomposition: A = UP where U is unitary, P is positive semidefinite.
Args:
a: Input matrix
is_hermitian: Whether matrix is Hermitian
max_iterations: Maximum number of iterations
dynamic_shape: Optional (rows, cols) of the meaningful region when the input is padded
Returns:
Tuple of (unitary_factor, positive_factor, num_iters, is_converged)
"""
def householder_product(a, taus) -> Array:
"""
Compute product of Householder reflectors.
Args:
a: Matrix containing Householder vectors
taus: Householder scaling factors
Returns:
Product of Householder reflectors
"""
def lu_pivots_to_permutation(pivots, permutation_size) -> Array:
"""
Convert LU pivots (row swaps) to a permutation.
Args:
pivots: Pivot indices from LU decomposition
permutation_size: Size of the permutation
Returns:
Integer array of permutation indices with trailing dimension permutation_size
"""

Convolution operations for neural networks and signal processing.

def conv(
lhs,
rhs,
window_strides,
padding,
precision=None,
preferred_element_type=None
) -> Array:
"""Basic convolution operation."""
def conv_general_dilated(
lhs,
rhs,
window_strides,
padding,
lhs_dilation=None,
rhs_dilation=None,
dimension_numbers=None,
feature_group_count=1,
batch_group_count=1,
precision=None,
preferred_element_type=None
) -> Array:
"""
General dilated convolution with full configuration options.
Args:
lhs: Input array (N...HWC or NCHW... format)
rhs: Kernel array
window_strides: Convolution strides
padding: Padding specification
lhs_dilation: Input dilation
rhs_dilation: Kernel dilation (atrous convolution)
dimension_numbers: Dimension layout specification
feature_group_count: Number of feature groups
batch_group_count: Number of batch groups
precision: Computation precision
preferred_element_type: Preferred output type
Returns:
Convolution result
"""
def conv_transpose(
lhs,
rhs,
strides,
padding,
rhs_dilation=None,
dimension_numbers=None,
transpose_kernel=False,
precision=None,
preferred_element_type=None
) -> Array:
"""Transposed (deconvolution) operation."""
class ConvDimensionNumbers:
"""Convolution dimension number specification."""
lhs_spec: tuple[int, ...] # Input dimension specification
rhs_spec: tuple[int, ...] # Kernel dimension specification
out_spec: tuple[int, ...] # Output dimension specification

Fast Fourier Transform operations.

def fft(a, fft_type, fft_lengths) -> Array:
"""
Fast Fourier Transform.
Args:
a: Input array
fft_type: Type of FFT (from FftType enum)
fft_lengths: Lengths of FFT dimensions
Returns:
FFT result
"""
class FftType:
"""FFT type enumeration."""
FFT = "FFT"
IFFT = "IFFT"
RFFT = "RFFT"
IRFFT = "IRFFT"

Multi-device communication primitives for distributed computing.

def all_gather(x, axis_name, *, axis_index_groups=None, tiled=False) -> Array:
"""Gather values from all devices."""
def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False) -> Array:
"""All-to-all communication between devices."""
def psum(x, axis_name, *, axis_index_groups=None) -> Array:
"""Parallel sum reduction across devices."""
def pmean(x, axis_name, *, axis_index_groups=None) -> Array:
"""Parallel mean reduction across devices."""
def pmax(x, axis_name, *, axis_index_groups=None) -> Array:
"""Parallel max reduction across devices."""
def pmin(x, axis_name, *, axis_index_groups=None) -> Array:
"""Parallel min reduction across devices."""
def ppermute(x, axis_name, perm, *, axis_index_groups=None) -> Array:
"""Permute data between devices."""
def axis_index(axis_name) -> Array:
"""Get device index along named axis."""
def axis_size(axis_name) -> int:
"""Get number of devices along named axis."""
def pbroadcast(x, axis_name, *, axis_index_groups=None) -> Array:
"""Broadcast from first device to all others."""

Special mathematical functions and probability distributions.

# Error functions
def erf(x) -> Array: ...
def erfc(x) -> Array: ...
def erf_inv(x) -> Array: ...
# Gamma functions
def lgamma(x) -> Array: ...
def digamma(x) -> Array: ...
def polygamma(m, x) -> Array: ...
# Bessel functions
def bessel_i0e(x) -> Array: ...
def bessel_i1e(x) -> Array: ...
# Other special functions
def betainc(a, b, x) -> Array: ...
def igamma(a, x) -> Array: ...
def igammac(a, x) -> Array: ...
def zeta(x, q=None) -> Array: ...

Array type conversion and data manipulation operations.

def convert_element_type(operand, new_dtype) -> Array:
"""Convert array element type."""
def bitcast_convert_type(operand, new_dtype) -> Array:
"""Bitcast array to new type without changing bit representation."""
def dtype(x) -> numpy.dtype:
"""Get array data type."""
def full(shape, fill_value, dtype=None) -> Array:
"""Create array filled with constant value."""
def full_like(x, fill_value, dtype=None, shape=None) -> Array:
"""Create filled array with same properties as input."""
def iota(dtype, size) -> Array:
"""Create array with sequential values (0, 1, 2, ...)."""
def broadcasted_iota(dtype, shape, dimension) -> Array:
"""Create iota array broadcasted to shape."""

Sorting and selection operations.

def sort(operand, dimension=-1, is_stable=True) -> Array:
"""Sort array along dimension."""
def sort_key_val(keys, values, dimension=-1, is_stable=True) -> tuple[Array, Array]:
"""Sort key-value pairs."""
def top_k(operand, k) -> tuple[Array, Array]:
"""Find top k largest elements and their indices."""
def argmax(operand, axis=None, index_dtype=int) -> Array:
"""Indices of maximum values."""
def argmin(operand, axis=None, index_dtype=int) -> Array:
"""Indices of minimum values."""

Additional utility operations and performance primitives.

def stop_gradient(x) -> Array:
"""Stop gradient computation at this point."""
def optimization_barrier(x) -> Array:
"""Prevent optimization across this point."""
def nextafter(x1, x2) -> Array:
"""Next representable value after x1 in direction of x2."""
def reduce_precision(operand, exponent_bits, mantissa_bits) -> Array:
"""Reduce floating-point precision."""
def create_token() -> Array:
"""Create execution token for ordering side effects."""
def after_all(*tokens) -> Array:
"""Create token that depends on all input tokens."""
# Random number generation primitives
def rng_uniform(a, b, shape, dtype=None) -> Array:
"""Low-level uniform random number generation."""
def rng_bit_generator(key, shape, dtype=None, algorithm=None) -> tuple[Array, Array]:
"""Low-level random bit generation."""

Install with Tessl CLI

npx tessl i tessl/pypi-jax