CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jax

Differentiate, compile, and transform Numpy code.

Pending
Overview
Eval results
Files

docs/low-level-ops.md

Low-Level Operations

The `jax.lax` module exposes XLA operations and primitives directly for high-performance computing. These low-level functions offer precise control over computation and serve as the building blocks for higher-level JAX APIs such as `jax.numpy`.

Core Imports

import jax.lax as lax
from jax.lax import add, mul, dot_general, cond, scan

Capabilities

Arithmetic Operations

Element-wise arithmetic operations that map directly to XLA primitives.

# NOTE: unlike jax.numpy, lax binary ops do not apply NumPy-style type
# promotion: array arguments must already share a dtype (Python scalars are
# accepted as weakly-typed values).
def add(x, y) -> Array:
    """Element-wise addition."""

def sub(x, y) -> Array:
    """Element-wise subtraction."""

def mul(x, y) -> Array:
    """Element-wise multiplication."""

def div(x, y) -> Array:
    """
    Element-wise division.

    For integer inputs the quotient is rounded toward zero (C semantics),
    not toward negative infinity as with Python's ``//`` or
    ``jax.numpy.floor_divide``.
    """

def rem(x, y) -> Array:
    """
    Element-wise remainder.

    The result takes the sign of the dividend ``x`` (C ``fmod`` semantics),
    not the sign of the divisor as with Python's ``%`` or ``jax.numpy.mod``.
    """

def max(x, y) -> Array:
    """Element-wise maximum."""

def min(x, y) -> Array:
    """Element-wise minimum."""

def abs(x) -> Array:
    """Element-wise absolute value."""

def neg(x) -> Array:
    """Element-wise negation."""

def sign(x) -> Array:
    """Element-wise sign function."""

def pow(x, y) -> Array:
    """Element-wise power ``x ** y`` with an array-valued exponent."""

def integer_pow(x, y) -> Array:
    """
    Element-wise power with a fixed integer exponent.

    ``y`` must be a static Python ``int`` (not a traced array); it is baked
    into the compiled program.
    """

def reciprocal(x) -> Array:
    """Element-wise reciprocal (1/x)."""

def square(x) -> Array:
    """Element-wise square."""

def sqrt(x) -> Array:
    """Element-wise square root."""

def rsqrt(x) -> Array:
    """Element-wise reciprocal square root (1/√x)."""

def cbrt(x) -> Array:
    """Element-wise cube root."""

def clamp(min, x, max) -> Array:
    """
    Clamp values between minimum and maximum.

    Computes ``maximum(min, minimum(x, max))`` element-wise. Note the
    argument order: the bounds surround ``x``.

    Args:
        min: Lower bound; a scalar or an array with the same shape as ``x``
        x: Input array
        max: Upper bound; a scalar or an array with the same shape as ``x``

    Returns:
        Array with values clamped to [min, max]
    """

Mathematical Functions

Transcendental and special mathematical functions.

# Trigonometric functions
def sin(x) -> Array: ...
def cos(x) -> Array: ...
def tan(x) -> Array: ...
def asin(x) -> Array: ...
def acos(x) -> Array: ...
def atan(x) -> Array: ...
def atan2(x, y) -> Array:
    """Element-wise two-argument arc tangent of ``x / y`` (note: ``x`` is the
    numerator); the signs of both arguments select the quadrant."""

# Hyperbolic functions
def sinh(x) -> Array: ...
def cosh(x) -> Array: ...
def tanh(x) -> Array: ...
def asinh(x) -> Array: ...
def acosh(x) -> Array: ...
def atanh(x) -> Array: ...

# Exponential and logarithmic
def exp(x) -> Array: ...
def exp2(x) -> Array: ...
def expm1(x) -> Array: ...  # exp(x) - 1, accurate for x near 0
def log(x) -> Array: ...
def log1p(x) -> Array: ...  # log(1 + x), accurate for x near 0
def logistic(x) -> Array:
    """Element-wise logistic sigmoid, ``1 / (1 + exp(-x))``."""

# Rounding operations
def ceil(x) -> Array: ...
def floor(x) -> Array: ...
def round(x, rounding_method=0) -> Array:
    """
    Element-wise rounding to the nearest integer-valued float.

    Args:
        x: Floating-point input array.
        rounding_method: Tie-breaking rule, a ``lax.RoundingMethod``. The
            default ``0`` is ``RoundingMethod.AWAY_FROM_ZERO`` (2.5 -> 3.0);
            ``RoundingMethod.TO_NEAREST_EVEN`` (2.5 -> 2.0) matches
            ``jax.numpy.round``.
    """

# Complex number operations
def complex(real, imag) -> Array:
    """Create complex array from real and imaginary parts."""

def conj(x) -> Array:
    """Complex conjugate."""

def real(x) -> Array:
    """Extract real part of complex array."""

def imag(x) -> Array:
    """Extract imaginary part of complex array."""

Comparison Operations

Element-wise comparison operations returning boolean arrays.

# Comparisons are element-wise and return boolean arrays.
def eq(x, y) -> Array:
    """Element-wise equality test (x == y); returns a boolean array."""

def ne(x, y) -> Array:
    """Element-wise inequality test (x != y); returns a boolean array."""

def lt(x, y) -> Array:
    """Element-wise less-than test (x < y); returns a boolean array."""

def le(x, y) -> Array:
    """Element-wise less-than-or-equal test (x <= y); returns a boolean array."""

def gt(x, y) -> Array:
    """Element-wise greater-than test (x > y); returns a boolean array."""

def ge(x, y) -> Array:
    """Element-wise greater-than-or-equal test (x >= y); returns a boolean array."""

def is_finite(x) -> Array:
    """Element-wise finiteness test; True where x is neither infinite nor NaN."""

Bitwise Operations

Bitwise operations on integer arrays.

# Bitwise logical operations
def bitwise_and(x, y) -> Array: ...
def bitwise_or(x, y) -> Array: ...
def bitwise_xor(x, y) -> Array: ...
def bitwise_not(x) -> Array: ...  # bitwise complement of x

# Bit shifting of x by y. A "logical" right shift fills vacated high bits
# with zeros; an "arithmetic" right shift replicates the sign bit.
def shift_left(x, y) -> Array: ...
def shift_right_logical(x, y) -> Array: ...
def shift_right_arithmetic(x, y) -> Array: ...

# Bit manipulation
def clz(x) -> Array:
    """Count leading zero bits of each element."""

def population_count(x) -> Array:
    """Count set (1) bits of each element."""

Array Operations

Shape manipulation, broadcasting, and array transformation operations.

def broadcast(operand, sizes) -> Array:
    """
    Broadcast an array by prepending new leading dimensions.

    Args:
        operand: Input array
        sizes: Sizes of the new leading dimensions; the result has shape
            ``tuple(sizes) + operand.shape``
    """

def broadcast_in_dim(operand, shape, broadcast_dimensions) -> Array:
    """
    Broadcast an array into a target shape.

    Args:
        operand: Input array
        shape: Shape of the result
        broadcast_dimensions: For each dimension of ``operand``, the index of
            the result dimension it maps to (one entry per operand dimension)
    """

def reshape(operand, new_sizes, dimensions=None) -> Array:
    """
    Reshape an array to new dimensions.

    Args:
        operand: Input array
        new_sizes: Target shape; its element count must equal ``operand``'s
        dimensions: Optional permutation of ``operand``'s axes applied before
            reshaping, i.e. ``reshape(transpose(operand, dimensions),
            new_sizes)``; None keeps the existing axis order
    """

def transpose(operand, permutation) -> Array:
    """Permute array axes: result axis ``i`` is ``operand`` axis ``permutation[i]``."""

def rev(operand, dimensions) -> Array:
    """Reverse array along specified dimensions."""

def concatenate(operands, dimension) -> Array:
    """Concatenate arrays along ``dimension``; all other dimensions must match."""

def pad(operand, padding_value, padding_config) -> Array:
    """
    Pad array with a constant value.

    Args:
        operand: Input array
        padding_value: Scalar used for all padding
        padding_config: One ``(low, high, interior)`` triple per dimension:
            elements added before, after, and between existing elements.
            Negative ``low``/``high`` remove elements instead.
    """

def squeeze(array, dimensions) -> Array:
    """Remove the given dimensions, each of which must have size 1."""

def expand_dims(array, dimensions) -> Array:
    """Insert size-1 dimensions at the given positions."""

Indexing and Slicing

Advanced indexing operations for array access and updates.

def slice(operand, start_indices, limit_indices, strides=None) -> Array:
    """
    Extract a static slice from an array.

    Args:
        operand: Input array
        start_indices: Start index per dimension (inclusive)
        limit_indices: End index per dimension (exclusive)
        strides: Optional stride per dimension; None means 1

    Returns:
        The sliced array
    """

def slice_in_dim(operand, start_index, limit_index, stride=1, axis=0) -> Array:
    """Slice ``[start_index:limit_index:stride]`` along a single ``axis``
    (a None start/limit means the beginning/end of the axis)."""

def dynamic_slice(operand, start_indices, slice_sizes) -> Array:
    """
    Extract a slice whose start indices may be traced (dynamic) values.

    ``slice_sizes`` must be static. Start indices are clamped so the slice
    always lies within ``operand``; out-of-bounds starts do not raise.
    """

def dynamic_slice_in_dim(operand, start_index, slice_size, axis=0) -> Array:
    """Dynamic slice of static length ``slice_size`` along a single ``axis``."""

def dynamic_update_slice(operand, update, start_indices) -> Array:
    """
    Return ``operand`` with ``update`` written at dynamic ``start_indices``.

    Start indices are clamped so that ``update`` fits within ``operand``.
    """

def dynamic_update_slice_in_dim(operand, update, start_index, axis) -> Array:
    """Dynamic update slice along a single ``axis``."""

def gather(
    operand,
    start_indices,
    dimension_numbers,
    slice_sizes,
    indices_are_sorted=False,
    unique_indices=False,
    mode=None,
    fill_value=None
) -> Array:
    """
    General gather operation for advanced indexing.

    ``mode`` (a ``lax.GatherScatterMode``) controls out-of-bounds indices;
    ``fill_value`` is what out-of-bounds gathers return in FILL_OR_DROP mode.
    """

def scatter(
    operand,
    scatter_indices,
    updates,
    dimension_numbers,
    indices_are_sorted=False,
    unique_indices=False,
    mode=None
) -> Array:
    """
    General scatter operation that overwrites values in ``operand``.

    If several updates target the same index, which one wins is
    implementation-defined; use a combining variant below when that matters.
    """

# Scatter variants that combine updates with existing values (add, subtract,
# multiply, max, min) instead of overwriting them.
def scatter_add(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_sub(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_mul(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_max(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...
def scatter_min(operand, scatter_indices, updates, dimension_numbers, **kwargs) -> Array: ...

def index_in_dim(operand, index, axis=0, keepdims=True) -> Array:
    """Take static ``index`` along ``axis``; ``keepdims`` keeps that axis with size 1."""

def index_take(src, idxs, axes) -> Array:
    """Take elements using multi-dimensional indices."""

Reduction Operations

Reduce arrays along specified axes using various operations.

def reduce(
    operands,
    init_values,
    computation,
    dimensions
) -> Array:
    """
    General reduction operation.

    Args:
        operands: Array to reduce, or a pytree of arrays reduced together
            (variadic reduction)
        init_values: Initial value(s) matching the structure of
            ``operands``; should be an identity of ``computation``
            (e.g. 0 for addition)
        computation: Binary function combining two values; it should be
            associative and commutative because the application order is
            unspecified
        dimensions: Axes to reduce over

    Returns:
        Reduced array (or pytree of arrays mirroring ``operands``)
    """

# Specialized reductions; ``axes`` is a sequence of axis indices to reduce.
def reduce_sum(operand, axes) -> Array: ...
def reduce_prod(operand, axes) -> Array: ...
def reduce_max(operand, axes) -> Array: ...
def reduce_min(operand, axes) -> Array: ...
def reduce_and(operand, axes) -> Array: ...
def reduce_or(operand, axes) -> Array: ...
def reduce_xor(operand, axes) -> Array: ...

# Windowed reductions
def reduce_window(
    operand,
    init_value,
    computation,
    window_dimensions,
    window_strides,
    padding,
    base_dilation=None,
    window_dilation=None
) -> Array:
    """
    Sliding window reduction.

    Args:
        operand: Input array
        init_value: Initial value for reduction
        computation: Binary reduction function
        window_dimensions: Window size for each dimension of ``operand``
        window_strides: Window stride for each dimension (required)
        padding: ``'SAME'``, ``'VALID'``, or one ``(low, high)`` pair per
            dimension (required)
        base_dilation: Per-dimension dilation of ``operand``; None means 1
        window_dilation: Per-dimension dilation of the window; None means 1

    Returns:
        Reduced array with window operation applied
    """

Control Flow

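Structured control-flow primitives (conditionals and loops) that are staged into the compiled program, so they can depend on traced values where ordinary Python `if`/`for` cannot.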

def cond(pred, true_fun, false_fun, *operands) -> Any:
    """
    Conditional execution based on predicate.

    ``pred`` may be a traced value. Both branches are traced, so they must
    accept the same ``operands`` and return outputs with identical tree
    structure, shapes, and dtypes. Under ``vmap`` with a batched predicate,
    both branches are executed and the results are selected.

    Args:
        pred: Boolean scalar predicate
        true_fun: Function to execute if pred is True
        false_fun: Function to execute if pred is False
        operands: Arguments to pass to selected function

    Returns:
        Result of executing selected function
    """

def select(pred, on_true, on_false) -> Array:
    """
    Element-wise conditional selection between two computed arrays.

    Takes ``on_true`` where ``pred`` is True and ``on_false`` elsewhere.
    ``pred`` is a boolean scalar or an array with the branches' shape.
    """

def select_n(which, *cases) -> Array:
    """
    Multi-way element-wise selection: returns ``cases[which]``.

    ``which`` is a boolean or integer scalar or array; a boolean ``which``
    picks ``cases[0]`` where False and ``cases[1]`` where True. All cases
    share one shape and dtype.
    """

def while_loop(cond_fun, body_fun, init_val) -> Any:
    """
    While loop with condition and body functions.

    ``body_fun`` must return a value with the same structure, shapes, and
    dtypes as ``init_val``. Not reverse-mode differentiable; use ``scan``
    (or ``fori_loop`` with static bounds) when gradients are needed.

    Args:
        cond_fun: Function of the loop state returning a boolean scalar
        body_fun: Function that updates loop state
        init_val: Initial loop state

    Returns:
        Final loop state after termination
    """

def fori_loop(lower, upper, body_fun, init_val) -> Any:
    """
    For loop over range with body function.

    With concrete (static) ``lower``/``upper`` the loop is lowered to
    ``scan`` and is reverse-mode differentiable; otherwise it is lowered to
    ``while_loop``.

    Args:
        lower: Loop start index
        upper: Loop end index (exclusive)
        body_fun: Function ``(i, state) -> new_state``
        init_val: Initial loop state

    Returns:
        Final loop state
    """

def scan(f, init, xs=None, length=None, reverse=False, unroll=1) -> tuple[Any, Any]:
    """
    Scan a function over leading array axes while carrying state.

    Semantically equivalent to::

        carry = init
        ys = []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, stack(ys)

    Args:
        f: Function ``(carry, x) -> (new_carry, y)``; ``new_carry`` must match
            ``carry`` in structure, shapes, and dtypes
        init: Initial carry value
        xs: Pytree of arrays scanned along their leading axis, or None to run
            ``length`` steps with ``x = None``
        length: Number of steps; required when ``xs`` is None, otherwise
            inferred from ``xs``
        reverse: Whether to scan in reverse
        unroll: Iterations unrolled per loop step (``True`` unrolls fully)

    Returns:
        Tuple of (final_carry, ys), with per-step outputs stacked along a new
        leading axis
    """

def associative_scan(fn, elems, reverse=False, axis=0) -> Array:
    """
    Parallel associative (prefix) scan operation.

    Args:
        fn: Associative binary function; associativity is required for a
            correct result
        elems: Array or pytree of arrays with equal size along ``axis``
        reverse: Whether to scan in reverse
        axis: Axis to scan along

    Returns:
        Scanned results with the same structure as ``elems``
    """

def switch(index, branches, *operands) -> Any:
    """
    Switch statement for multi-way branching.

    ``index`` is clamped to ``[0, len(branches) - 1]``, so out-of-range
    values select the first or last branch. All branches must return outputs
    with identical structure, shapes, and dtypes.

    Args:
        index: Integer index selecting branch
        branches: Sequence of functions (branches)
        operands: Arguments to pass to selected branch

    Returns:
        Result of executing selected branch
    """

def map(f, xs, *, batch_size=None) -> Any:
    """
    Map ``f`` sequentially over the leading axis of ``xs`` (via ``scan``).

    If ``batch_size`` is given, ``f`` is vectorized with ``vmap`` over chunks
    of that many elements per step.
    """

Cumulative Operations

Cumulative operations along array axes.

def cumsum(operand, axis=0, reverse=False) -> Array:
    """
    Cumulative sum along ``axis``.

    Args:
        operand: Input array
        axis: Integer axis to accumulate along (default 0). Unlike
            ``jax.numpy.cumsum``, the input is never flattened.
        reverse: If True, accumulate from the end of the axis toward the start
    """

def cumprod(operand, axis=0, reverse=False) -> Array:
    """Cumulative product along integer ``axis`` (default 0)."""

def cummax(operand, axis=0, reverse=False) -> Array:
    """Cumulative maximum along integer ``axis`` (default 0)."""

def cummin(operand, axis=0, reverse=False) -> Array:
    """Cumulative minimum along integer ``axis`` (default 0)."""

def cumlogsumexp(operand, axis=0, reverse=False) -> Array:
    """Cumulative log-sum-exp, ``log(cumsum(exp(operand)))`` computed stably,
    along integer ``axis`` (default 0)."""

Linear Algebra

Matrix operations and linear algebra primitives.

def dot(lhs, rhs, precision=None, preferred_element_type=None) -> Array:
    """Vector/matrix product of 1-D or 2-D arrays; use ``dot_general`` for
    batched or higher-rank contractions."""

def dot_general(
    lhs,
    rhs,
    dimension_numbers,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    General matrix multiplication with custom contractions.

    Args:
        lhs: Left-hand side array
        rhs: Right-hand side array
        dimension_numbers: Nested tuple
            ``((lhs_contracting_dims, rhs_contracting_dims),
            (lhs_batch_dims, rhs_batch_dims))``
        precision: Computation precision (e.g. ``lax.Precision.HIGHEST``)
        preferred_element_type: Preferred output element type

    Returns:
        Result whose dimensions are ordered: batch dimensions, then the
        remaining ``lhs`` dimensions, then the remaining ``rhs`` dimensions
    """

def batch_matmul(
    lhs,
    rhs,
    precision=None,
    preferred_element_type=None
) -> Array:
    """Batched matrix multiplication; ``lhs`` and ``rhs`` must have the same
    rank (at least 2) and matching leading batch dimensions."""

# Dimension specification for ``dot_general``. In JAX this is a type alias
# for a nested tuple, not a class with named fields:
#     ((lhs_contracting_dims, rhs_contracting_dims),
#      (lhs_batch_dims, rhs_batch_dims))
DotDimensionNumbers = tuple[
    tuple[tuple[int, ...], tuple[int, ...]],
    tuple[tuple[int, ...], tuple[int, ...]],
]

Advanced Linear Algebra (lax.linalg)

Advanced linear algebra operations from jax.lax.linalg.

def cholesky(a, *, symmetrize_input: bool = True) -> Array:
    """
    Cholesky decomposition of a Hermitian positive-definite matrix.

    Args:
        a: Hermitian positive-definite matrix (or batch of matrices)
        symmetrize_input: If True, ``a`` is symmetrized as ``(a + a^H) / 2``
            before factoring; if False, only the lower triangle is read

    Returns:
        Lower-triangular factor ``L`` with ``a = L @ L^H``. A matrix that is
        not positive definite yields NaNs rather than an exception.
    """

def cholesky_update(r_matrix, w_vector) -> Array:
    """
    Rank-1 update of a Cholesky factorization.

    Given the upper-triangular ``R`` with ``A = R^T R`` and a vector ``w``,
    returns the upper-triangular Cholesky factor of ``A + w w^T``.

    Args:
        r_matrix: Upper-triangular Cholesky factor ``R``
        w_vector: Update vector ``w``

    Returns:
        Updated upper-triangular Cholesky factor
    """

def eig(a, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True) -> list[Array]:
    """
    Eigenvalue decomposition of a general (non-Hermitian) square matrix.

    Args:
        a: Input matrix
        compute_left_eigenvectors: Whether to compute left eigenvectors
        compute_right_eigenvectors: Whether to compute right eigenvectors

    Returns:
        A list: the (complex) eigenvalues, then the left eigenvectors if
        requested, then the right eigenvectors if requested. With the
        defaults this is ``[w, vl, vr]``.
    """

def eigh(a, *, lower: bool = True, symmetrize_input: bool = True, sort_eigenvalues: bool = True, subset_by_index=None) -> tuple[Array, Array]:
    """
    Eigenvalue decomposition of a Hermitian matrix.

    Args:
        a: Hermitian matrix
        lower: When ``symmetrize_input`` is False, whether to read the lower
            (True) or upper (False) triangle
        symmetrize_input: Whether to symmetrize input
        sort_eigenvalues: Whether to sort eigenvalues in ascending order
        subset_by_index: Optional ``(start, end)`` range of eigenvalue
            indices to compute; None computes all

    Returns:
        Tuple ``(v, w)`` -- eigenvectors FIRST. ``v[..., :, i]`` is the
        normalized eigenvector for eigenvalue ``w[..., i]``. This is the
        reverse of ``jax.numpy.linalg.eigh``, which returns ``(w, v)``.
    """

def lu(a) -> tuple[Array, Array, Array]:
    """
    LU decomposition with partial pivoting.

    Args:
        a: Input matrix

    Returns:
        Tuple of (lu, pivots, permutation): ``lu`` holds ``L`` (unit
        diagonal not stored) below the diagonal and ``U`` on and above it;
        ``pivots`` is an int32 sequence of row swaps; ``permutation`` is the
        same row swaps expressed as an int32 permutation of the rows
    """

def qr(a, *, full_matrices: bool = True) -> tuple[Array, Array]:
    """
    QR decomposition.

    Args:
        a: Input matrix of shape (..., m, n)
        full_matrices: If True, ``q`` is (m, m) and ``r`` is (m, n); if False,
            ``q`` is (m, k) and ``r`` is (k, n) with ``k = min(m, n)``

    Returns:
        Tuple of (q, r) matrices
    """

def svd(a, *, full_matrices: bool = True, compute_uv: bool = True, subset_by_index=None) -> tuple[Array, Array, Array]:
    """
    Singular value decomposition.

    Args:
        a: Input matrix
        full_matrices: Whether to return full or reduced SVD
        compute_uv: Whether to compute U and V matrices
        subset_by_index: Optional ``(start, end)`` range of singular value
            indices to compute; None computes all

    Returns:
        If ``compute_uv`` is True, a tuple ``(u, s, vh)`` with
        ``a = u @ diag(s) @ vh``; otherwise only the singular values ``s``
    """

def schur(a, *, compute_schur_vectors: bool = True, sort_eig_vals: bool = False, select_callable=None) -> tuple[Array, Array]:
    """
    Schur decomposition.

    Args:
        a: Input matrix
        compute_schur_vectors: Whether to compute Schur vectors
        sort_eig_vals: Whether to reorder eigenvalues using ``select_callable``
        select_callable: Selection function for eigenvalues

    Returns:
        Tuple of (schur_form, schur_vectors)
    """

def hessenberg(a) -> tuple[Array, Array]:
    """
    Hessenberg decomposition.

    Args:
        a: Input matrix

    Returns:
        Tuple ``(a, taus)``: the upper triangle and first subdiagonal of the
        returned ``a`` hold the upper-Hessenberg matrix, entries below the
        first subdiagonal hold Householder reflectors, and ``taus`` holds
        their scaling factors
    """

def triangular_solve(a, b, *, left_side: bool = False, lower: bool = False, transpose_a: bool = False, conjugate_a: bool = False, unit_diagonal: bool = False) -> Array:
    """
    Solve triangular system of equations.

    Note that ``left_side`` and ``lower`` both default to False.

    Args:
        a: Triangular matrix
        b: Right-hand side
        left_side: If True solve ``op(a) @ x = b``; otherwise solve
            ``x @ op(a) = b``, where ``op`` applies the transpose/conjugate
            options below
        lower: Whether ``a`` is lower triangular (False: upper)
        transpose_a: Whether to transpose A
        conjugate_a: Whether to conjugate A
        unit_diagonal: Whether to assume A has a unit diagonal (its diagonal
            entries are not read)

    Returns:
        Solution to triangular system
    """

def tridiagonal(a, *, lower: bool = True) -> tuple[Array, Array, Array, Array]:
    """
    Reduce a symmetric/Hermitian matrix to tridiagonal form.

    Args:
        a: Symmetric or Hermitian matrix
        lower: Whether to use the lower (True) or upper triangle

    Returns:
        Tuple ``(arr, d, e, taus)``: ``arr`` holds the Householder reflectors,
        ``d`` the diagonal and ``e`` the off-diagonal of the tridiagonal
        matrix, and ``taus`` the reflector scaling factors
    """

def tridiagonal_solve(dl, d, du, b) -> Array:
    """
    Solve a tridiagonal linear system ``A x = b``.

    Args:
        dl: Lower diagonal, ``dl[i] = A[i, i-1]``; ``dl[0]`` must be 0
        d: Main diagonal, ``d[i] = A[i, i]``
        du: Upper diagonal, ``du[i] = A[i, i+1]``; ``du[-1]`` must be 0
        b: Right-hand side matrix of shape (..., m, n)

    Returns:
        Solution to tridiagonal system
    """

def qdwh(a, *, is_hermitian: bool = False, max_iterations: int = None, eps: float = None, dynamic_shape: tuple[int, int] = None) -> tuple[Array, Array, Array, Array]:
    """
    QDWH polar decomposition ``a = u @ h``, with ``u`` unitary and ``h``
    Hermitian positive semidefinite.

    Args:
        a: Input matrix
        is_hermitian: Whether matrix is Hermitian
        max_iterations: Maximum number of iterations (None: library default)
        eps: Convergence tolerance (None: based on the dtype's precision)
        dynamic_shape: Optional unpadded ``(m, n)`` shape of ``a``

    Returns:
        Tuple ``(u, h, num_iters, is_converged)``
    """

def householder_product(a, taus) -> Array:
    """
    Compute product of Householder reflectors.

    Args:
        a: Matrix containing Householder vectors
        taus: Householder scaling factors

    Returns:
        Product of Householder reflectors
    """

def lu_pivots_to_permutation(pivots, permutation_size) -> Array:
    """
    Convert LU pivots (a sequence of row swaps) to a permutation.

    Args:
        pivots: int32 pivot indices of shape (..., k) from LU decomposition
        permutation_size: Length of the output permutation; must be >= k

    Returns:
        int32 array of shape (..., permutation_size) giving the permutation
        as row indices (not as a permutation matrix)
    """

Convolution Operations

Convolution operations for neural networks and signal processing.

def conv(
    lhs,
    rhs,
    window_strides,
    padding,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Basic convolution: a convenience wrapper around ``conv_general_dilated``.

    It uses the default layouts (input NCHW, kernel OIHW, output NCHW).
    ``padding`` is ``'SAME'`` or ``'VALID'``.
    """

def conv_general_dilated(
    lhs,
    rhs,
    window_strides,
    padding,
    lhs_dilation=None,
    rhs_dilation=None,
    dimension_numbers=None,
    feature_group_count=1,
    batch_group_count=1,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    General dilated convolution with full configuration options.

    Args:
        lhs: Input array, laid out as given by ``dimension_numbers``
            (default: batch, feature, then spatial dims, i.e. NCHW)
        rhs: Kernel array (default layout: out-feature, in-feature, spatial
            dims, i.e. OIHW)
        window_strides: Convolution strides
        padding: ``'SAME'``, ``'VALID'``, or one ``(low, high)`` pair per
            spatial dimension
        lhs_dilation: Input dilation (as used by transposed convolution)
        rhs_dilation: Kernel dilation (atrous convolution)
        dimension_numbers: None (NCHW/OIHW/NCHW), a ``ConvDimensionNumbers``,
            or a tuple of three layout strings such as
            ``('NHWC', 'HWIO', 'NHWC')``
        feature_group_count: Number of feature groups (grouped / depthwise
            convolution)
        batch_group_count: Number of batch groups
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Convolution result
    """

def conv_transpose(
    lhs,
    rhs,
    strides,
    padding,
    rhs_dilation=None,
    dimension_numbers=None,
    transpose_kernel=False,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Transposed convolution (often called "deconvolution").

    ``dimension_numbers`` is interpreted as for ``conv_general_dilated``. If
    ``transpose_kernel`` is True, the kernel's spatial axes are flipped and
    its input/output feature axes swapped. This matches gradient-derived
    layers such as Keras ``Conv2DTranspose``.
    """

class ConvDimensionNumbers:
    """Convolution dimension number specification (dimension indices)."""
    lhs_spec: tuple[int, ...]  # (batch dim, feature dim, *spatial dims) of the input
    rhs_spec: tuple[int, ...]  # (out-feature dim, in-feature dim, *spatial dims) of the kernel
    out_spec: tuple[int, ...]  # (batch dim, feature dim, *spatial dims) of the output

FFT Operations

Fast Fourier Transform operations.

def fft(a, fft_type, fft_lengths) -> Array:
    """
    Fast Fourier Transform over the trailing dimensions of ``a``.

    Args:
        a: Input array
        fft_type: An ``FftType`` member selecting the transform
        fft_lengths: Lengths of the transformed trailing dimensions (one to
            three of them)

    Returns:
        FFT result
    """

class FftType:
    """FFT type enumeration (an ``enum.IntEnum`` in JAX)."""
    FFT = 0    # forward complex-to-complex
    IFFT = 1   # inverse complex-to-complex
    RFFT = 2   # forward real-to-complex
    IRFFT = 3  # inverse complex-to-real

Parallel Operations

Multi-device communication primitives for distributed computing.

def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False) -> Array:
    """
    Gather values from all devices along the named axis.

    The gathered values are stacked along a new dimension at position
    ``axis``, or concatenated along ``axis`` when ``tiled`` is True.
    """

def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False) -> Array:
    """
    All-to-all communication between devices.

    ``x`` is split into chunks along ``split_axis``, chunk ``i`` is sent to
    device ``i``, and the received chunks are concatenated along
    ``concat_axis``.
    """

def psum(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel sum reduction across devices."""

def pmean(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel mean reduction across devices."""

def pmax(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel max reduction across devices."""

def pmin(x, axis_name, *, axis_index_groups=None) -> Array:
    """Parallel min reduction across devices."""

def ppermute(x, axis_name, perm) -> Array:
    """
    Permute data between devices.

    ``perm`` is a sequence of ``(source_index, destination_index)`` pairs;
    devices that are not a destination receive zeros.
    """

def axis_index(axis_name) -> Array:
    """Get device index along named axis."""

def axis_size(axis_name) -> int:
    """Get number of devices along named axis."""

def pbroadcast(x, axis_name, *, axis_index_groups=None) -> Array:
    """Broadcast from one device to all others along the named axis."""
    # NOTE(review): recent JAX versions document this as
    # pbroadcast(x, axis_name, source), broadcasting from the device at index
    # ``source`` -- confirm against the installed version.

Special Functions

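Special mathematical functions (error, gamma, Bessel, incomplete gamma/beta, and zeta functions), several of which serve as building blocks for probability-distribution CDFs.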

# Error functions
def erf(x) -> Array: ...
def erfc(x) -> Array: ...  # complementary error function, 1 - erf(x)
def erf_inv(x) -> Array: ...  # inverse of erf

# Gamma functions
def lgamma(x) -> Array: ...  # log of the absolute value of the gamma function
def digamma(x) -> Array: ...  # derivative of lgamma
def polygamma(m, x) -> Array: ...  # m-th derivative of digamma

# Bessel functions (exponentially scaled modified Bessel functions of the first kind)
def bessel_i0e(x) -> Array: ...  # exp(-|x|) * I0(x)
def bessel_i1e(x) -> Array: ...  # exp(-|x|) * I1(x)

# Other special functions
def betainc(a, b, x) -> Array: ...  # regularized incomplete beta function
def igamma(a, x) -> Array: ...  # regularized lower incomplete gamma function
def igammac(a, x) -> Array: ...  # regularized upper incomplete gamma function
# NOTE(review): zeta is the Hurwitz zeta function; confirm whether jax.lax.zeta
# really accepts q=None -- the lax version may require q explicitly.
def zeta(x, q=None) -> Array: ...

Type Conversion and Manipulation

Array type conversion and data manipulation operations.

def convert_element_type(operand, new_dtype) -> Array:
    """Convert the array's element type to ``new_dtype`` (a value conversion,
    in contrast to ``bitcast_convert_type``)."""

def bitcast_convert_type(operand, new_dtype) -> Array:
    """Reinterpret the array's bits as ``new_dtype`` without changing the bit representation."""
    # NOTE(review): if new_dtype's bit width differs from the input's, the
    # result shape presumably gains or loses a trailing dimension -- confirm.

# NOTE(review): confirm that jax.lax exports ``dtype``; the public helper may
# live in jax.numpy or jax.dtypes instead.
def dtype(x) -> numpy.dtype:
    """Get the array's data type."""

def full(shape, fill_value, dtype=None) -> Array:
    """Create an array of ``shape`` filled with the constant ``fill_value``."""

def full_like(x, fill_value, dtype=None, shape=None) -> Array:
    """Create a filled array with the same properties as ``x`` (optionally
    overriding ``dtype`` and ``shape``)."""

def iota(dtype, size) -> Array:
    """Create an array of length ``size`` with sequential values (0, 1, 2, ...)."""

def broadcasted_iota(dtype, shape, dimension) -> Array:
    """Create an array of ``shape`` whose values count 0, 1, 2, ... along ``dimension``."""

Sorting Operations

Sorting and selection operations.

def sort(operand, dimension=-1, is_stable=True, num_keys=1) -> Array:
    """
    Sort along ``dimension``.

    Args:
        operand: An array, or a sequence of same-shaped arrays sorted
            together (lexicographically by the first ``num_keys`` of them)
        dimension: Dimension to sort along
        is_stable: Whether to use a stable sort
        num_keys: Number of leading arrays in ``operand`` used as sort keys
    """

def sort_key_val(keys, values, dimension=-1, is_stable=True) -> tuple[Array, Array]:
    """Sort ``keys`` along ``dimension`` and permute ``values`` identically."""

def top_k(operand, k) -> tuple[Array, Array]:
    """Return ``(values, indices)`` of the ``k`` largest entries along the last axis."""

def argmax(operand, axis, index_dtype) -> Array:
    """Indices of maximum values along ``axis``; ``axis`` and ``index_dtype``
    are required (unlike ``jax.numpy.argmax``)."""

def argmin(operand, axis, index_dtype) -> Array:
    """Indices of minimum values along ``axis``; ``axis`` and ``index_dtype``
    are required (unlike ``jax.numpy.argmin``)."""

Miscellaneous Operations

Additional utility operations and performance primitives.

def stop_gradient(x) -> Array:
    """
    Stop gradient computation at this point.

    Acts as the identity in the forward pass but is treated as a constant by
    automatic differentiation, so no gradient flows through it. ``x`` may be
    a pytree.
    """

def optimization_barrier(x) -> Array:
    """
    Prevent optimization across this point.

    Returns ``x`` (which may be a pytree) unchanged. The compiler does not
    move computations across the barrier, which helps control
    rematerialization and scheduling.
    """

def nextafter(x1, x2) -> Array:
    """Next representable value after x1 in direction of x2."""

def reduce_precision(operand, exponent_bits, mantissa_bits) -> Array:
    """
    Reduce floating-point precision.

    Rounds ``operand`` to the values representable in a floating-point
    format with the given exponent and mantissa bit counts. The result keeps
    the input dtype.
    """

def create_token() -> Array:
    """Create execution token for ordering side effects."""

def after_all(*tokens) -> Array:
    """Create token that depends on all input tokens."""

# Random number generation primitives. Prefer ``jax.random`` for most uses;
# these expose the underlying XLA operations directly.
def rng_uniform(a, b, shape, dtype=None) -> Array:
    """
    Low-level uniform random number generation in ``[a, b)``.

    Stateful and experimental; its use is discouraged in favor of
    ``jax.random``.
    """
    # NOTE(review): jax.lax.rng_uniform is documented as (a, b, shape) with
    # the dtype taken from a and b -- confirm whether ``dtype`` exists.

def rng_bit_generator(key, shape, dtype=None, algorithm=None) -> tuple[Array, Array]:
    """
    Low-level random bit generation.

    Returns:
        Tuple ``(new_key, bits)``: the updated generator state and an array
        of random bits with the requested ``shape`` and unsigned-integer
        ``dtype``
    """

Install with Tessl CLI

npx tessl i tessl/pypi-jax

docs

core-transformations.md

device-memory.md

experimental.md

index.md

low-level-ops.md

neural-networks.md

numpy-compatibility.md

random-numbers.md

scipy-compatibility.md

tree-operations.md

tile.json