Differentiate, compile, and transform NumPy code.
—
JAX provides a NumPy-compatible API through jax.numpy (commonly imported as jnp). JAX arrays are immutable; jax.numpy closely mirrors the NumPy API and adds JIT compilation, automatic differentiation, and device acceleration.
import jax.numpy as jnp
import jax
Create JAX arrays from various data sources and specifications.
def array(object, dtype=None, copy=None, order=None, ndmin=0) -> Array:
"""Create array from array-like object."""
def asarray(a, dtype=None, order=None) -> Array:
"""Convert input to array."""
def zeros(shape, dtype=None) -> Array:
"""Create array filled with zeros."""
def zeros_like(a, dtype=None, shape=None) -> Array:
"""Create zeros array with same shape as input."""
def ones(shape, dtype=None) -> Array:
"""Create array filled with ones."""
def ones_like(a, dtype=None, shape=None) -> Array:
"""Create ones array with same shape as input."""
def full(shape, fill_value, dtype=None) -> Array:
"""Create array filled with constant value."""
def full_like(a, fill_value, dtype=None, shape=None) -> Array:
"""Create filled array with same shape as input."""
def empty(shape, dtype=None) -> Array:
"""Create uninitialized array."""
def empty_like(a, dtype=None, shape=None) -> Array:
"""Create empty array with same shape as input."""
def eye(N, M=None, k=0, dtype=None) -> Array:
"""Create identity matrix."""
def identity(n, dtype=None) -> Array:
"""Create square identity matrix."""
def arange(start, stop=None, step=None, dtype=None) -> Array:
"""Create evenly spaced values within interval."""
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0) -> Array:
"""Create evenly spaced numbers over interval."""
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0) -> Array:
"""Create numbers spaced evenly on log scale."""
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0) -> Array:
"""Create numbers spaced evenly on log scale (geometric progression)."""
def meshgrid(*xi, copy=True, sparse=False, indexing='xy') -> list[Array]:
"""Create coordinate matrices from coordinate vectors."""
def mgrid() -> MGridClass:
"""Multi-dimensional mesh creation."""
def ogrid() -> OGridClass:
"""Open multi-dimensional mesh creation."""
def indices(dimensions, dtype=int, sparse=False) -> Array:
"""Create arrays of indices."""
def tri(N, M=None, k=0, dtype=None) -> Array:
"""Create array with ones at and below diagonal."""
Element-wise mathematical operations following NumPy conventions.
# Arithmetic operations
def add(x1, x2) -> Array: ...
def subtract(x1, x2) -> Array: ...
def multiply(x1, x2) -> Array: ...
def divide(x1, x2) -> Array: ...
def true_divide(x1, x2) -> Array: ...
def floor_divide(x1, x2) -> Array: ...
def power(x1, x2) -> Array: ...
def float_power(x1, x2) -> Array: ...
def mod(x1, x2) -> Array: ...
def remainder(x1, x2) -> Array: ...
def divmod(x1, x2) -> tuple[Array, Array]: ...
# Trigonometric functions
def sin(x) -> Array: ...
def cos(x) -> Array: ...
def tan(x) -> Array: ...
def asin(x) -> Array: ...
def acos(x) -> Array: ...
def atan(x) -> Array: ...
def atan2(x1, x2) -> Array: ...
def sinh(x) -> Array: ...
def cosh(x) -> Array: ...
def tanh(x) -> Array: ...
def asinh(x) -> Array: ...
def acosh(x) -> Array: ...
def atanh(x) -> Array: ...
def degrees(x) -> Array: ...
def radians(x) -> Array: ...
def deg2rad(x) -> Array: ...
def rad2deg(x) -> Array: ...
# Exponential and logarithmic
def exp(x) -> Array: ...
def exp2(x) -> Array: ...
def expm1(x) -> Array: ...
def log(x) -> Array: ...
def log10(x) -> Array: ...
def log2(x) -> Array: ...
def log1p(x) -> Array: ...
# Rounding and precision
def round(a, decimals=0) -> Array: ...
def rint(x) -> Array: ...
def fix(x) -> Array: ...
def floor(x) -> Array: ...
def ceil(x) -> Array: ...
def trunc(x) -> Array: ...
# Arithmetic functions
def abs(x) -> Array: ...
def absolute(x) -> Array: ...
def fabs(x) -> Array: ...
def sign(x) -> Array: ...
def signbit(x) -> Array: ...
def copysign(x1, x2) -> Array: ...
def sqrt(x) -> Array: ...
def square(x) -> Array: ...
def cbrt(x) -> Array: ...
def reciprocal(x) -> Array: ...
def positive(x) -> Array: ...
def negative(x) -> Array: ...
# Extrema functions
def maximum(x1, x2) -> Array: ...
def minimum(x1, x2) -> Array: ...
def fmax(x1, x2) -> Array: ...
def fmin(x1, x2) -> Array: ...
def clip(a, a_min=None, a_max=None) -> Array: ...
# Complex number functions
def real(val) -> Array: ...
def imag(val) -> Array: ...
def conj(x) -> Array: ...
def conjugate(x) -> Array: ...
def angle(z, deg=False) -> Array: ...
def isreal(x) -> Array: ...
def iscomplex(x) -> Array: ...
# Floating point functions
def isfinite(x) -> Array: ...
def isinf(x) -> Array: ...
def isnan(x) -> Array: ...
def isneginf(x) -> Array: ...
def isposinf(x) -> Array: ...
def nextafter(x1, x2) -> Array: ...
def spacing(x) -> Array: ...
def modf(x) -> tuple[Array, Array]: ...
def frexp(x) -> tuple[Array, Array]: ...
def ldexp(x1, x2) -> Array: ...
Functions for reshaping, combining, and transforming arrays.
# Shape manipulation
def reshape(a, newshape, order='C') -> Array: ...
def ravel(a, order='C') -> Array: ...
def flatten(a, order='C') -> Array: ...
# Transpose operations
def transpose(a, axes=None) -> Array: ...
def swapaxes(a, axis1, axis2) -> Array: ...
def moveaxis(a, source, destination) -> Array: ...
def rollaxis(a, axis, start=0) -> Array: ...
# Dimension manipulation
def expand_dims(a, axis) -> Array: ...
def squeeze(a, axis=None) -> Array: ...
# Array reversal and rotation
def flip(m, axis=None) -> Array: ...
def fliplr(m) -> Array: ...
def flipud(m) -> Array: ...
def rot90(m, k=1, axes=(0, 1)) -> Array: ...
def roll(a, shift, axis=None) -> Array: ...
# Broadcasting
def broadcast_to(array, shape) -> Array: ...
def broadcast_arrays(*args) -> list[Array]: ...
# Joining arrays
def concatenate(arrays, axis=0) -> Array: ...
def stack(arrays, axis=0) -> Array: ...
def vstack(tup) -> Array: ...
def hstack(tup) -> Array: ...
def dstack(tup) -> Array: ...
def column_stack(tup) -> Array: ...
def append(arr, values, axis=None) -> Array: ...
# Splitting arrays
def split(ary, indices_or_sections, axis=0) -> list[Array]: ...
def array_split(ary, indices_or_sections, axis=0) -> list[Array]: ...
def hsplit(ary, indices_or_sections) -> list[Array]: ...
def vsplit(ary, indices_or_sections) -> list[Array]: ...
def dsplit(ary, indices_or_sections) -> list[Array]: ...
# Tiling and repeating
def tile(A, reps) -> Array: ...
def repeat(a, repeats, axis=None) -> Array: ...
# Array modification
def insert(arr, obj, values, axis=None) -> Array: ...
def delete(arr, obj, axis=None) -> Array: ...
def place(arr, mask, vals) -> None: ...
def put(a, ind, v, mode='raise') -> None: ...
def put_along_axis(arr, indices, values, axis) -> None: ...
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, equal_nan=True) -> Array: ...
Advanced indexing, selection, and conditional operations.
def take(a, indices, axis=None, mode=None) -> Array:
"""Take elements from array along axis."""
def take_along_axis(arr, indices, axis) -> Array:
"""Take values from array using indices along axis."""
def choose(a, choices, mode='raise') -> Array:
"""Construct array from index array and choice arrays."""
def compress(condition, a, axis=None) -> Array:
"""Return selected slices along axis."""
def extract(condition, arr) -> Array:
"""Return elements satisfying condition."""
def select(condlist, choicelist, default=0) -> Array:
"""Return elements chosen from choicelist based on conditions."""
def where(condition, x=None, y=None) -> Array:
"""Return elements chosen from x or y based on condition."""
def nonzero(a) -> tuple[Array, ...]:
"""Return indices of non-zero elements."""
def argwhere(a) -> Array:
"""Return indices where condition is True."""
def flatnonzero(a) -> Array:
"""Return indices of flattened array that are non-zero."""
def ix_(*args) -> tuple[Array, ...]:
"""Construct open mesh from multiple sequences."""
Functions that reduce arrays along axes or compute aggregates.
# Basic reductions
def sum(a, axis=None, dtype=None, keepdims=False, initial=None, where=None) -> Array: ...
def prod(a, axis=None, dtype=None, keepdims=False, initial=None, where=None) -> Array: ...
def mean(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
def median(a, axis=None, keepdims=False) -> Array: ...
def std(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
def var(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
# Extrema
def min(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
def max(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
def amin(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
def amax(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
def ptp(a, axis=None, keepdims=False) -> Array: ...
# Percentiles and quantiles
def percentile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
def quantile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
# Cumulative operations
def cumsum(a, axis=None, dtype=None) -> Array: ...
def cumprod(a, axis=None, dtype=None) -> Array: ...
# Logical reductions
def all(a, axis=None, keepdims=False, where=None) -> Array: ...
def any(a, axis=None, keepdims=False, where=None) -> Array: ...
# Counting
def count_nonzero(a, axis=None, keepdims=False) -> Array: ...
# NaN-aware reductions
def nansum(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
def nanprod(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
def nanmean(a, axis=None, dtype=None, keepdims=False, where=None) -> Array: ...
def nanmedian(a, axis=None, keepdims=False) -> Array: ...
def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=False, where=None) -> Array: ...
def nanmin(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
def nanmax(a, axis=None, keepdims=False, initial=None, where=None) -> Array: ...
def nanpercentile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
def nanquantile(a, q, axis=None, method='linear', keepdims=False) -> Array: ...
def nancumsum(a, axis=None, dtype=None) -> Array: ...
def nancumprod(a, axis=None, dtype=None) -> Array: ...
# Indices of extrema
def argmin(a, axis=None, keepdims=False) -> Array: ...
def argmax(a, axis=None, keepdims=False) -> Array: ...
def nanargmin(a, axis=None, keepdims=False) -> Array: ...
def nanargmax(a, axis=None, keepdims=False) -> Array: ...
Core linear algebra operations for matrix computations.
# Matrix multiplication
def dot(a, b) -> Array: ...
def matmul(x1, x2) -> Array: ...
def inner(a, b) -> Array: ...
def outer(a, b) -> Array: ...
def tensordot(a, b, axes=2) -> Array: ...
def kron(a, b) -> Array: ...
# Vector operations
def vdot(a, b) -> Array: ...
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None) -> Array: ...
# Matrix operations
def trace(a, offset=0, axis1=0, axis2=1, dtype=None) -> Array: ...
def diagonal(a, offset=0, axis1=0, axis2=1) -> Array: ...
def diag(v, k=0) -> Array: ...
def diagflat(v, k=0) -> Array: ...
# Triangular matrices
def tril(m, k=0) -> Array: ...
def triu(m, k=0) -> Array: ...
def tril_indices(n, k=0, m=None) -> tuple[Array, Array]: ...
def triu_indices(n, k=0, m=None) -> tuple[Array, Array]: ...
def diag_indices(n, ndim=2) -> tuple[Array, ...]: ...
# Matrix transpose
def matrix_transpose(x) -> Array: ...
Functions for sorting arrays and searching for values.
def sort(a, axis=-1, kind='stable', order=None) -> Array: ...
def argsort(a, axis=-1, kind='stable', order=None) -> Array: ...
def lexsort(keys, axis=-1) -> Array: ...
def partition(a, kth, axis=-1, kind='introselect', order=None) -> Array: ...
def argpartition(a, kth, axis=-1, kind='introselect', order=None) -> Array: ...
def searchsorted(a, v, side='left', sorter=None) -> Array: ...
def sort_complex(a) -> Array: ...
Set-like operations on arrays.
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None) -> Array: ...
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False) -> Array: ...
def union1d(ar1, ar2) -> Array: ...
def setdiff1d(ar1, ar2, assume_unique=False) -> Array: ...
def setxor1d(ar1, ar2, assume_unique=False) -> Array: ...
def isin(element, test_elements, assume_unique=False, invert=False) -> Array: ...
Statistical analysis and distribution functions.
def bincount(x, weights=None, minlength=0, length=None) -> Array: ...
def histogram(a, bins=10, range=None, weights=None, density=None) -> tuple[Array, Array]: ...
def histogram2d(x, y, bins=10, range=None, weights=None, density=None) -> tuple[Array, Array, Array]: ...
def histogramdd(sample, bins=10, range=None, weights=None, density=None) -> tuple[Array, list[Array]]: ...
def histogram_bin_edges(a, bins=10, range=None, weights=None) -> Array: ...
def digitize(x, bins, right=False) -> Array: ...
def average(a, axis=None, weights=None, returned=False, keepdims=False) -> Array: ...
def corrcoef(x, y=None, rowvar=True, dtype=None) -> Array: ...
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None) -> Array: ...
def gradient(f, *varargs, axis=None, edge_order=1) -> Array: ...
Type information, checking, and conversion functions.
# Type checking
def issubdtype(arg1, arg2) -> bool: ...
def can_cast(from_, to, casting='safe') -> bool: ...
def result_type(*arrays_and_dtypes): ...
def promote_types(type1, type2): ...
def isscalar(element) -> bool: ...
def isrealobj(x) -> bool: ...
def iscomplexobj(x) -> bool: ...
# Type information
def finfo(dtype): ...
def iinfo(dtype): ...
# Array properties
def ndim(a) -> int: ...
def shape(a) -> tuple: ...
def size(a) -> int: ...
# Comparison functions
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool: ...
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) -> Array: ...
def array_equal(a1, a2, equal_nan=False) -> bool: ...
def array_equiv(a1, a2) -> bool: ...
# Utility functions
def copy(a, order='K') -> Array: ...
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None) -> Array: ...
Element-wise comparison functions returning boolean arrays.
def equal(x1, x2) -> Array: ...
def not_equal(x1, x2) -> Array: ...
def less(x1, x2) -> Array: ...
def less_equal(x1, x2) -> Array: ...
def greater(x1, x2) -> Array: ...
def greater_equal(x1, x2) -> Array: ...
Element-wise logical operations on boolean arrays.
def logical_and(x1, x2) -> Array: ...
def logical_or(x1, x2) -> Array: ...
def logical_not(x) -> Array: ...
def logical_xor(x1, x2) -> Array: ...
Element-wise bitwise operations on integer arrays.
def bitwise_and(x1, x2) -> Array: ...
def bitwise_or(x1, x2) -> Array: ...
def bitwise_xor(x1, x2) -> Array: ...
def bitwise_not(x) -> Array: ...
def bitwise_left_shift(x1, x2) -> Array: ...
def bitwise_right_shift(x1, x2) -> Array: ...
def left_shift(x1, x2) -> Array: ...
def right_shift(x1, x2) -> Array: ...
def invert(x) -> Array: ...
def bitwise_count(x) -> Array: ...
Mathematical and numerical constants.
pi: float # π (3.14159...)
e: float # Euler's number (2.71828...)
euler_gamma: float # Euler-Mascheroni constant
inf: float # Positive infinity
nan: float # Not a Number
newaxis: None # Used for adding dimensions in indexing
import jax.numpy.fft as jfft
# 1D transforms
jfft.fft(a, n=None, axis=-1, norm=None) -> Array
jfft.ifft(a, n=None, axis=-1, norm=None) -> Array
jfft.rfft(a, n=None, axis=-1, norm=None) -> Array
jfft.irfft(a, n=None, axis=-1, norm=None) -> Array
# 2D transforms
jfft.fft2(a, s=None, axes=(-2, -1), norm=None) -> Array
jfft.ifft2(a, s=None, axes=(-2, -1), norm=None) -> Array
jfft.rfft2(a, s=None, axes=(-2, -1), norm=None) -> Array
jfft.irfft2(a, s=None, axes=(-2, -1), norm=None) -> Array
# N-D transforms
jfft.fftn(a, s=None, axes=None, norm=None) -> Array
jfft.ifftn(a, s=None, axes=None, norm=None) -> Array
jfft.rfftn(a, s=None, axes=None, norm=None) -> Array
jfft.irfftn(a, s=None, axes=None, norm=None) -> Array
# Helper functions
jfft.fftfreq(n, d=1.0) -> Array
jfft.rfftfreq(n, d=1.0) -> Array
jfft.fftshift(x, axes=None) -> Array
jfft.ifftshift(x, axes=None) -> Array
import jax.numpy.linalg as jla
# Matrix decompositions
jla.cholesky(a) -> Array
jla.qr(a, mode='reduced') -> tuple[Array, Array]
jla.svd(a, full_matrices=True, compute_uv=True, hermitian=False) -> tuple[Array, Array, Array]
jla.eig(a) -> tuple[Array, Array]
jla.eigh(a, UPLO='L') -> tuple[Array, Array]
jla.eigvals(a) -> Array
jla.eigvalsh(a, UPLO='L') -> Array
# Matrix properties
jla.det(a) -> Array
jla.slogdet(a) -> tuple[Array, Array]
jla.matrix_rank(M, tol=None, hermitian=False) -> Array
jla.trace(a, offset=0, axis1=0, axis2=1, dtype=None) -> Array
# Matrix solutions
jla.solve(a, b) -> Array
jla.lstsq(a, b, rcond=None) -> tuple[Array, Array, Array, Array]
jla.inv(a) -> Array
jla.pinv(a, rcond=None, hermitian=False) -> Array
# Norms and distances
jla.norm(x, ord=None, axis=None, keepdims=False) -> Array
jla.cond(x, p=None) -> Array
# Matrix functions
jla.matrix_power(a, n) -> Array
Install with Tessl CLI
npx tessl i tessl/pypi-jax