CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-jax

Differentiate, compile, and transform NumPy code.

Pending
Overview
Eval results
Files

docs/scipy-compatibility.md

SciPy Compatibility

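JAX provides SciPy-compatible functions through jax.scipy for scientific computing, including linear algebra, signal processing, special functions, statistics, and sparse linear solvers. Most of these functions are differentiable and can be JIT-compiled. Coverage is partial: jax.scipy implements only a subset of SciPy, and some entries listed below mirror SciPy or jax.numpy rather than jax.scipy, so check the JAX API reference before relying on a particular function or parameter.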

Core Imports

import jax.scipy as jsp
import jax.scipy.linalg as jla
import jax.scipy.special as jss
import jax.scipy.stats as jst

Capabilities

Linear Algebra (jax.scipy.linalg)

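Advanced linear algebra operations for matrix computations and decompositions. Several functions listed here (for example slogdet, norm, pinv and lstsq) are provided by jax.numpy.linalg rather than jax.scipy.linalg; check which module provides each one.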

# Matrix decompositions
def cholesky(a, lower=True) -> Array:
    """
    Cholesky decomposition of a Hermitian positive-definite matrix.
    
    Args:
        a: Hermitian (real symmetric) positive-definite matrix to decompose
        lower: If True, return the lower-triangular factor L; if False,
            return the upper-triangular factor U
        
    Returns:
        Triangular factor with a = L @ L.conj().T (lower=True) or
        a = U.conj().T @ U (lower=False); .conj() is a no-op for real input

    NOTE(review): jax.scipy.linalg.cholesky follows scipy.linalg.cholesky
    and defaults to lower=False (upper factor). The lower=True default shown
    in this stub looks wrong -- confirm against the JAX API reference.
    """

def qr(a, mode='reduced') -> tuple[Array, Array]:
    """
    QR decomposition of a matrix.
    
    Args:
        a: Matrix to decompose
        mode: Which variant of the decomposition to compute (see note)
        
    Returns:
        Tuple (Q, R) with a = Q @ R, where Q has orthonormal columns
        (unitary if complex) and R is upper triangular

    NOTE(review): 'reduced'/'complete' are jax.numpy.linalg.qr modes.
    jax.scipy.linalg.qr follows SciPy instead: mode in {'full', 'r',
    'economic'} with default 'full', and mode='r' returns only R -- verify
    which function this stub is meant to describe.
    """

def svd(a, full_matrices=True, compute_uv=True, hermitian=False) -> tuple[Array, Array, Array]:
    """
    Singular Value Decomposition.
    
    Args:
        a: Matrix to decompose
        full_matrices: If True, U and Vh are square; if False, they are
            truncated to the reduced (economy) shapes
        compute_uv: If False, only the singular values are computed
        hermitian: Whether the matrix is Hermitian (enables a cheaper path)
        
    Returns:
        Tuple (U, s, Vh) with a = U @ diag(s) @ Vh and s in descending
        order; when compute_uv=False only s is returned

    NOTE(review): `hermitian` is a jax.numpy.linalg.svd parameter;
    jax.scipy.linalg.svd takes overwrite_a, check_finite and lapack_driver
    instead -- confirm against the JAX API reference.
    """

def eig(a, b=None, left=False, right=True, overwrite_a=False, overwrite_b=False, 
        check_finite=True, homogeneous_eigvals=False) -> tuple[Array, Array]:
    """
    Eigenvalues and eigenvectors of a general (non-symmetric) square matrix.
    
    Args:
        a: Square matrix
        b: Optional matrix for the generalized problem a @ v = w * b @ v
        left: Whether to compute left eigenvectors
        right: Whether to compute right eigenvectors
        overwrite_a: SciPy-compatibility flag; JAX arrays are immutable, so
            presumably ignored -- verify
        overwrite_b: Same as overwrite_a, for b
        check_finite: Whether to check for finite values
        homogeneous_eigvals: Whether to return eigenvalues in homogeneous
            (alpha, beta) form
        
    Returns:
        Tuple (eigenvalues, eigenvectors); eigenvalues are generally complex
        even for real input, and eigenvectors are stored as columns

    NOTE(review): jax.scipy.linalg may not provide eig at all; the JAX
    equivalent is jax.numpy.linalg.eig(a), which historically ran on CPU
    only -- verify before use.
    """

def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
         overwrite_b=False, turbo=True, eigvals=None, type=1,
         check_finite=True) -> tuple[Array, Array]:
    """
    Eigenvalues and eigenvectors of a Hermitian (real symmetric) matrix.
    
    Args:
        a: Hermitian matrix
        b: Optional matrix for the generalized problem
        lower: Whether to read the lower (True) or upper (False) triangle of a
        eigvals_only: If True, compute eigenvalues only
        overwrite_a: SciPy-compatibility flag (JAX arrays are immutable)
        overwrite_b: SciPy-compatibility flag (JAX arrays are immutable)
        turbo: SciPy-compatibility flag
        eigvals: Range of eigenvalue indices to compute
        type: Type of generalized eigenvalue problem
        check_finite: Whether to check for finite values
        
    Returns:
        Eigenvalues in ascending order, plus the matching eigenvectors as
        columns when eigvals_only=False (a tuple (w, v) in that case)

    NOTE(review): JAX is believed to reject b (generalized problems) and
    eigvals (subset selection) -- confirm against the JAX API reference.
    """

def eigvals(a, b=None, overwrite_a=False, check_finite=True, 
           homogeneous_eigvals=False) -> Array:
    """
    Eigenvalues of a general square matrix (generally complex).

    NOTE(review): not believed to exist in jax.scipy.linalg;
    jax.numpy.linalg.eigvals(a) is the JAX counterpart -- verify.
    """

def eigvalsh(a, b=None, lower=True, overwrite_a=False, overwrite_b=False,
            turbo=True, eigvals=None, type=1, check_finite=True) -> Array:
    """
    Eigenvalues of a Hermitian matrix, in ascending order.

    NOTE(review): not believed to exist in jax.scipy.linalg; use
    jax.scipy.linalg.eigh(a, eigvals_only=True) or
    jax.numpy.linalg.eigvalsh(a) -- verify.
    """

# Matrix properties and functions
# NOTE(review): apart from det, these appear to live in jax.numpy.linalg
# (trace also as jax.numpy.trace) rather than jax.scipy.linalg -- verify.
def det(a) -> Array:
    """Determinant of a square matrix."""

def slogdet(a) -> tuple[Array, Array]:
    """
    Sign and natural log of the absolute determinant, (sign, logabsdet).

    det(a) == sign * exp(logabsdet), but without overflow/underflow.
    """

def logdet(a) -> Array:
    """
    Log determinant of a matrix.

    NOTE(review): no public logdet is known in jax.scipy.linalg or
    jax.numpy.linalg; slogdet(a)[1] gives log|det(a)| -- verify.
    """

def matrix_rank(M, tol=None, hermitian=False) -> Array:
    """
    Numerical rank: the number of singular values above a tolerance.

    NOTE(review): recent jax.numpy.linalg.matrix_rank takes `rtol` rather
    than `tol` -- verify the parameter name for the JAX version in use.
    """

def trace(a, offset=0, axis1=0, axis2=1) -> Array:
    """Sum along the diagonal selected by offset, axis1 and axis2."""

def norm(a, ord=None, axis=None, keepdims=False) -> Array:
    """Matrix or vector norm; ord=None gives the Frobenius / 2-norm."""

def cond(x, p=None) -> Array:
    """Condition number of a matrix in the norm selected by p (2-norm if None)."""

# Matrix solutions
def solve(a, b, assume_a='gen', lower=False, overwrite_a=False, 
          overwrite_b=False, debug=None, check_finite=True) -> Array:
    """
    Solve the linear system a @ x = b for x.
    
    Args:
        a: Square coefficient matrix
        b: Right-hand side vector/matrix
        assume_a: Structure of a: 'gen' (general), 'sym' (symmetric),
            'her' (Hermitian) or 'pos' (positive definite)
        lower: For the structured cases ('sym', 'her', 'pos'), read only
            the lower triangle of a instead of the upper one
        overwrite_a: SciPy-compatibility flag (JAX arrays are immutable)
        overwrite_b: SciPy-compatibility flag (JAX arrays are immutable)
        debug: SciPy-compatibility argument; presumably unused -- verify
        check_finite: Whether to check for finite values
        
    Returns:
        Solution x such that a @ x = b
    """

def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
                    overwrite_b=False, debug=None, check_finite=True) -> Array:
    """
    Solve a @ x = b for triangular a.

    trans selects the system: 0/'N' solves a @ x = b, 1/'T' solves
    a.T @ x = b, and 2/'C' solves a.conj().T @ x = b. If unit_diagonal is
    True, the diagonal of a is assumed to be ones and is not read.
    """

def inv(a, overwrite_a=False, check_finite=True) -> Array:
    """Inverse of a square matrix (prefer solve() for solving systems)."""

def pinv(a, rcond=None, hermitian=False, return_rank=False) -> Array:
    """
    Moore-Penrose pseudoinverse.

    NOTE(review): the JAX function is jax.numpy.linalg.pinv; it has no
    return_rank argument, and newer versions name the cutoff `rtol` --
    verify.
    """

def lstsq(a, b, rcond=None, lapack_driver=None) -> tuple[Array, Array, Array, Array]:
    """
    Least-squares solution to a linear system, minimizing ||b - a @ x||_2.
    
    Args:
        a: Coefficient matrix
        b: Dependent variable values
        rcond: Cutoff ratio for small singular values
        lapack_driver: LAPACK driver to use
        
    Returns:
        Tuple (solution, residuals, rank, singular_values)

    NOTE(review): the JAX function is jax.numpy.linalg.lstsq, which takes
    no lapack_driver argument -- verify.
    """

# Matrix functions
def expm(A) -> Array:
    """Matrix exponential exp(A) (not the elementwise exponential)."""

def funm(A, func, disp=True) -> Array:
    """Evaluate a matrix function f(A) defined by the scalar callable func."""

def sqrtm(A, disp=True, blocksize=64) -> Array:
    """
    Principal matrix square root X with X @ X == A.

    NOTE(review): the jax.scipy.linalg.sqrtm signature/defaults may differ
    from this stub -- verify against the JAX API reference.
    """

def logm(A, disp=True) -> Array:
    """
    Principal matrix logarithm (inverse of expm).

    NOTE(review): not believed to be provided by jax.scipy.linalg -- verify.
    """

def fractional_matrix_power(A, t) -> Array:
    """
    Fractional matrix power A^t.

    NOTE(review): not believed to be provided by JAX -- verify.
    """

def matrix_power(A, n) -> Array:
    """Integer matrix power A^n by repeated multiplication (jax.numpy.linalg)."""

# Schur decomposition
def schur(a, output='real') -> tuple[Array, Array]:
    """
    Schur decomposition (T, Z) with a = Z @ T @ Z.conj().T.

    Z is unitary. T is quasi-upper-triangular for output='real' and upper
    triangular for output='complex'.
    """

def rsf2csf(T, Z) -> tuple[Array, Array]:
    """Convert a real Schur form (T, Z) to the complex Schur form (T2, Z2)."""

# Polar decomposition
def polar(a, side='right') -> tuple[Array, Array]:
    """
    Polar decomposition (u, p).

    u is unitary and p is Hermitian positive semidefinite, with a = u @ p
    for side='right' and a = p @ u for side='left'.
    """

Special Functions (jax.scipy.special)

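Special mathematical functions, including error functions, gamma functions, and Bessel functions. JAX implements only a subset of scipy.special (Bessel support, for instance, is much narrower than SciPy's), so verify that a function exists in jax.scipy.special before relying on it.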

# Error functions
def erf(z) -> Array:
    """Error function erf(z) = 2/sqrt(pi) * integral_0^z exp(-t**2) dt."""

def erfc(x) -> Array:
    """Complementary error function erfc(x) = 1 - erf(x)."""

def erfinv(y) -> Array:
    """Inverse error function: erf(erfinv(y)) == y for -1 < y < 1."""

def erfcinv(y) -> Array:
    """
    Inverse complementary error function, equal to erfinv(1 - y).

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

def wofz(z) -> Array:
    """
    Faddeeva function w(z) = exp(-z**2) * erfc(-1j * z).

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

# Gamma functions
def gamma(z) -> Array:
    """Gamma function; it has poles at 0, -1, -2, ..."""

def gammaln(x) -> Array:
    """Natural log of the absolute value of the gamma function, log|Gamma(x)|."""

def digamma(x) -> Array:
    """Digamma (psi) function, the derivative of log Gamma(x)."""

def polygamma(n, x) -> Array:
    """Polygamma function: the n-th derivative of the digamma function."""

def gammainc(a, x) -> Array:
    """Regularized lower incomplete gamma function P(a, x) = gamma(a, x) / Gamma(a)."""

def gammaincc(a, x) -> Array:
    """Regularized upper incomplete gamma function Q(a, x) = 1 - P(a, x)."""

def gammasgn(x) -> Array:
    """Sign of the gamma function."""

def rgamma(x) -> Array:
    """Reciprocal gamma function 1 / Gamma(x) (zero at the poles of Gamma)."""

# Beta functions
def beta(a, b) -> Array:
    """Beta function B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)."""

def betaln(a, b) -> Array:
    """Natural log of the absolute value of the beta function."""

def betainc(a, b, x) -> Array:
    """Regularized incomplete beta function I_x(a, b)."""

# Bessel functions
# NOTE(review): as far as I know, jax.scipy.special implements only part of
# this family (i0, i0e, i1, i1e, plus a bessel_jn helper); the remaining
# names mirror scipy.special -- verify each one before use.
def j0(x) -> Array:
    """Bessel function of the first kind of order 0."""

def j1(x) -> Array:
    """Bessel function of the first kind of order 1."""

def jn(n, x) -> Array:
    """Bessel function of the first kind of integer order n."""

def y0(x) -> Array:
    """Bessel function of the second kind of order 0."""

def y1(x) -> Array:
    """Bessel function of the second kind of order 1."""

def yn(n, x) -> Array:
    """Bessel function of the second kind of integer order n."""

def i0(x) -> Array:
    """Modified Bessel function of the first kind of order 0."""

def i0e(x) -> Array:
    """Exponentially scaled modified Bessel function: exp(-|x|) * i0(x)."""

def i1(x) -> Array:
    """Modified Bessel function of the first kind of order 1."""

def i1e(x) -> Array:
    """Exponentially scaled modified Bessel function: exp(-|x|) * i1(x)."""

def iv(v, z) -> Array:
    """Modified Bessel function of the first kind of real order v."""

def k0(x) -> Array:
    """Modified Bessel function of the second kind of order 0."""

def k0e(x) -> Array:
    """Exponentially scaled modified Bessel function: exp(x) * k0(x)."""

def k1(x) -> Array:
    """Modified Bessel function of the second kind of order 1."""

def k1e(x) -> Array:
    """Exponentially scaled modified Bessel function: exp(x) * k1(x)."""

def kv(v, z) -> Array:
    """Modified Bessel function of the second kind of real order v."""

# Exponential integrals
def expi(x) -> Array:
    """Exponential integral Ei(x)."""

def expn(n, x) -> Array:
    """Generalized exponential integral E_n(x) = integral_1^inf exp(-x*t) / t**n dt."""

# Log-sum-exp and related
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False) -> Array:
    """
    Compute log(sum(b * exp(a))) in a numerically stable way.
    
    Args:
        a: Input array
        axis: Axis (or axes) to reduce over; None reduces over all elements
        b: Optional scaling factors for exp(a), broadcastable against a
        keepdims: Whether to keep reduced dimensions with size 1
        return_sign: If True, also return the sign of the weighted sum
            (needed when b has negative entries)
        
    Returns:
        The log-sum-exp value; when return_sign=True, a tuple
        (log|sum|, sign) instead
    """

def softmax(x, axis=None) -> Array:
    """Softmax exp(x) / sum(exp(x)) over axis (all elements when axis=None)."""

def log_softmax(x, axis=None) -> Array:
    """Log softmax x - logsumexp(x) over axis (all elements when axis=None)."""

# Combinatorial functions
def factorial(n, exact=False) -> Array:
    """
    Factorial n!, evaluated in floating point.

    NOTE(review): JAX is believed to reject exact=True (it cannot return
    arbitrary-precision Python ints under tracing) -- verify.
    """

def factorial2(n, exact=False) -> Array:
    """
    Double factorial n!! = n * (n - 2) * (n - 4) * ...

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

def factorialk(n, k, exact=False) -> Array:
    """
    Multifactorial of n with step k.

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

def comb(N, k, exact=False, repetition=False) -> Array:
    """
    Binomial coefficient N! / (k! * (N - k)!); with repetition=True,
    comb(N + k - 1, k).

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

def perm(N, k, exact=False) -> Array:
    """
    Number of k-permutations of N items, N! / (N - k)!.

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

# Elliptic integrals
# SciPy convention: m is the parameter, m = k**2 (not the modulus k).
# NOTE(review): these may not exist in jax.scipy.special -- verify.
def ellipk(m) -> Array:
    """Complete elliptic integral of the first kind, K(m)."""

def ellipe(m) -> Array:
    """Complete elliptic integral of the second kind, E(m)."""

def ellipkinc(phi, m) -> Array:
    """Incomplete elliptic integral of the first kind, F(phi | m)."""

def ellipeinc(phi, m) -> Array:
    """Incomplete elliptic integral of the second kind, E(phi | m)."""

# Zeta and related functions
def zeta(x, q=None) -> Array:
    """
    Hurwitz zeta sum_{n>=0} (n + q)**(-x); Riemann zeta when q is None.

    NOTE(review): some JAX versions require q to be given explicitly --
    verify for the version in use.
    """

def zetac(x) -> Array:
    """
    Riemann zeta function minus 1, zeta(x) - 1.

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

# Hypergeometric functions
def hyp1f1(a, b, x) -> Array:
    """Confluent hypergeometric function 1F1(a; b; x) (Kummer's M)."""

def hyp2f1(a, b, c, z) -> Array:
    """Gaussian hypergeometric function 2F1(a, b; c; z)."""

def hyperu(a, b, x) -> Array:
    """
    Confluent hypergeometric function U(a, b, x) (Tricomi's U).

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

# Legendre functions
def legendre(n, x) -> Array:
    """
    Legendre polynomial P_n evaluated at x.

    NOTE(review): scipy.special.legendre(n) returns a polynomial object and
    does not take x (eval_legendre(n, x) evaluates it). JAX offers
    lpmn / lpmn_values instead -- verify before documenting this signature.
    """

def lpmv(m, v, x) -> Array:
    """
    Associated Legendre function of integer order m and real degree v.

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

# Spherical functions
def sph_harm(m, n, theta, phi) -> Array:
    """
    Spherical harmonic Y_n^m.

    Uses the legacy SciPy argument convention: theta is the azimuthal angle
    and phi the polar angle.

    NOTE(review): the JAX version has an extra n_max argument and newer
    releases favor sph_harm_y -- verify.
    """

# Other special functions
def lambertw(z, k=0, tol=1e-8) -> Array:
    """
    Lambert W function on branch k (k=0 is the principal branch).

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

def spence(z) -> Array:
    """Spence function (dilogarithm) integral_1^z log(t) / (1 - t) dt."""

def multigammaln(a, d) -> Array:
    """Natural log of the multivariate gamma function of dimension d."""

def entr(x) -> Array:
    """Elementwise -x*log(x): 0 at x == 0, -inf for x < 0."""

def kl_div(x, y) -> Array:
    """Elementwise x*log(x/y) - x + y for x, y > 0 (y when x == 0; inf otherwise)."""

def rel_entr(x, y) -> Array:
    """Elementwise x*log(x/y) for x, y > 0 (0 when x == 0; inf otherwise)."""

def huber(delta, r) -> Array:
    """
    Huber loss: r**2 / 2 if |r| <= delta, else delta * (|r| - delta / 2).

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

def pseudo_huber(delta, r) -> Array:
    """
    Pseudo-Huber loss: delta**2 * (sqrt(1 + (r / delta)**2) - 1).

    NOTE(review): may not exist in jax.scipy.special -- verify.
    """

Statistics (jax.scipy.stats)

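Probability distributions, evaluated through functions such as pdf/logpdf/cdf (pmf/logpmf for discrete distributions), plus a few statistical helper functions. Distributions do not provide random sampling; draw samples with jax.random instead.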

# Continuous distributions
# In JAX each distribution is a module (e.g. jax.scipy.stats.norm) exposing
# plain functions; the classes below are only a documentation device.
# `loc`/`scale` shift and stretch the standard form:
#     pdf(x, ..., loc, scale) == pdf((x - loc) / scale, ...) / scale
# There is no sampling (rvs); use jax.random for that.
# NOTE(review): the exact set of functions per distribution varies by JAX
# version -- verify before relying on a particular one.
class norm:
    """Normal distribution with mean `loc` and standard deviation `scale`."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod  
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logcdf(x, loc=0, scale=1) -> Array: ...
    # Survival function: 1 - cdf.
    @staticmethod
    def sf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logsf(x, loc=0, scale=1) -> Array: ...
    # Percent-point function: inverse of cdf, evaluated at probability q.
    @staticmethod
    def ppf(q, loc=0, scale=1) -> Array: ...
    # Inverse survival function: inverse of sf.
    @staticmethod
    def isf(q, loc=0, scale=1) -> Array: ...

class multivariate_normal:
    """
    Multivariate normal distribution with mean vector `mean` and covariance
    matrix `cov`.

    NOTE(review): in jax.scipy.stats.multivariate_normal, mean and cov are
    believed to be required and allow_singular unsupported -- verify.
    """
    @staticmethod
    def pdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...
    @staticmethod
    def logpdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...

class uniform:
    """Uniform distribution on [loc, loc + scale]."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logcdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def sf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logsf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def ppf(q, loc=0, scale=1) -> Array: ...

class beta:
    """Beta distribution with shape parameters a, b on [loc, loc + scale]."""
    @staticmethod
    def pdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, a, b, loc=0, scale=1) -> Array: ...

class gamma:
    """Gamma distribution with shape `a` and scale `scale` (rate 1/scale)."""
    @staticmethod
    def pdf(x, a, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, a, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, a, loc=0, scale=1) -> Array: ...

class chi2:
    """Chi-square distribution with `df` degrees of freedom."""
    @staticmethod
    def pdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, df, loc=0, scale=1) -> Array: ...

class t:
    """Student's t-distribution with `df` degrees of freedom."""
    @staticmethod
    def pdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, df, loc=0, scale=1) -> Array: ...

class f:
    """
    F-distribution with `dfn` numerator and `dfd` denominator degrees of
    freedom.

    NOTE(review): may not be provided by jax.scipy.stats -- verify.
    """
    @staticmethod  
    def pdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...

class laplace:
    """Laplace (double exponential) distribution centred at `loc`."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...

class logistic:
    """Logistic distribution centred at `loc`."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...

class pareto:
    """Pareto distribution with shape `b`; support x >= loc + scale."""
    @staticmethod
    def pdf(x, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, b, loc=0, scale=1) -> Array: ...

class expon:
    """Exponential distribution with rate 1/scale; support x >= loc."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...

class lognorm:
    """Log-normal distribution: `s` is the std of log(x), `scale` is exp(mean of log(x))."""
    @staticmethod
    def pdf(x, s, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, s, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, s, loc=0, scale=1) -> Array: ...

class truncnorm:
    """
    Normal distribution truncated to [a, b], where a and b are given in
    standardized units, i.e. (bound - loc) / scale.
    """
    @staticmethod
    def pdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, a, b, loc=0, scale=1) -> Array: ...

# Discrete distributions
# pmf/logpmf/cdf are evaluated at k - loc (loc shifts the support).
class bernoulli:
    """Bernoulli distribution with success probability p."""
    @staticmethod
    def pmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, p, loc=0) -> Array: ...

class binom:
    """Binomial distribution: successes in n trials with success probability p."""
    @staticmethod
    def pmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, n, p, loc=0) -> Array: ...

class geom:
    """Geometric distribution: trials up to and including the first success (k >= 1)."""
    @staticmethod
    def pmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, p, loc=0) -> Array: ...

class nbinom:
    """Negative binomial distribution: failures before the n-th success."""
    @staticmethod
    def pmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, n, p, loc=0) -> Array: ...

class poisson:
    """Poisson distribution with mean mu."""
    @staticmethod
    def pmf(k, mu, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, mu, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, mu, loc=0) -> Array: ...

# Statistical functions
def mode(a, axis=0, nan_policy='propagate', keepdims=False) -> Array:
    """
    Most common value along axis.

    Like SciPy, the result is a (mode, count) pair rather than a single
    array.
    """

def rankdata(a, method='average', axis=None) -> Array:
    """
    Ranks of the data, starting at 1.

    Ties are resolved by method: 'average', 'min', 'max', 'dense' or
    'ordinal'.
    """

def kendalltau(x, y, initial_lexsort=None, nan_policy='propagate', method='auto') -> tuple[Array, Array]:
    """
    Kendall's tau rank correlation, returned as (statistic, pvalue).

    NOTE(review): may not be provided by jax.scipy.stats -- verify.
    """

def pearsonr(x, y) -> tuple[Array, Array]:
    """
    Pearson correlation coefficient, returned as (statistic, pvalue).

    NOTE(review): may not be provided by jax.scipy.stats -- verify.
    """

def spearmanr(a, b=None, axis=0, nan_policy='propagate', alternative='two-sided') -> tuple[Array, Array]:
    """
    Spearman rank correlation, returned as (statistic, pvalue).

    NOTE(review): may not be provided by jax.scipy.stats -- verify.
    """

Signal Processing (jax.scipy.signal)

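Signal processing functions for convolution, correlation, filtering, and spectral analysis. jax.scipy.signal covers only part of scipy.signal; several functions listed below (for example the IIR filters, resampling, and peak finding) may be unavailable in JAX.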

# NOTE(review): as far as I know, jax.scipy.signal implements only convolve,
# convolve2d, correlate, correlate2d, fftconvolve, csd, welch, stft, istft
# and detrend. The other stubs below mirror scipy.signal -- verify each one
# against the JAX API reference before use.
def convolve(in1, in2, mode='full', method='auto') -> Array:
    """N-dimensional convolution; mode is 'full', 'same' or 'valid', method 'auto', 'direct' or 'fft'."""

def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
    """
    2D convolution.

    NOTE(review): JAX is believed to support only boundary='fill' with
    fillvalue=0 -- verify.
    """

def correlate(in1, in2, mode='full', method='auto') -> Array:
    """N-dimensional cross-correlation of in1 with in2."""

def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
    """2D cross-correlation (same boundary caveat as convolve2d)."""

def fftconvolve(in1, in2, mode='full', axes=None) -> Array:
    """Convolution computed via the FFT."""

def oaconvolve(in1, in2, mode='full', axes=None) -> Array:
    """Convolution computed with the overlap-add method."""

def lfilter(b, a, x, axis=-1, zi=None) -> Array:
    """Filter x with the IIR/FIR filter given by numerator b and denominator a."""

def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad', irlen=None) -> Array:
    """Zero-phase filtering: apply the filter forward and then backward."""

def sosfilt(sos, x, axis=-1, zi=None) -> Array:
    """Filter x with a cascade of second-order sections."""

def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None) -> Array:
    """Zero-phase filtering with second-order sections."""

def hilbert(x, N=None, axis=-1) -> Array:
    """Analytic signal x + 1j * H(x), where H is the Hilbert transform (computed via FFT)."""

def hilbert2(x, N=None) -> Array:
    """2D analytic signal of x."""

def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True) -> Array:
    """Downsample by integer factor q after applying an anti-aliasing filter."""

def resample(x, num, t=None, axis=0, window=None, domain='time') -> Array:
    """Resample x to num samples using the Fourier method."""

def resample_poly(x, up, down, axis=0, window='kaiser', padtype='constant', cval=None) -> Array:
    """Resample by the rational factor up/down using polyphase filtering."""

def upfirdn(h, x, up=1, down=1, axis=-1, mode='constant', cval=0) -> Array:
    """Upsample, FIR filter, and downsample."""

def periodogram(x, fs=1.0, window='boxcar', nfft=None, detrend='constant', 
               return_onesided=True, scaling='density', axis=-1) -> tuple[Array, Array]:
    """Power spectral density estimate via the periodogram; returns (f, Pxx)."""

def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
         detrend='constant', return_onesided=True, scaling='density', axis=-1,
         average='mean') -> tuple[Array, Array]:
    """Power spectral density via Welch's method; returns (f, Pxx)."""

def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density', axis=-1,
        average='mean') -> tuple[Array, Array]:
    """Cross power spectral density of x and y; returns (f, Pxy)."""

def coherence(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
             detrend='constant', axis=-1) -> tuple[Array, Array]:
    """Magnitude-squared coherence of x and y; returns (f, Cxy)."""

def spectrogram(x, fs=1.0, window='tukey', nperseg=None, noverlap=None, nfft=None,
               detrend='constant', return_onesided=True, scaling='density', axis=-1,
               mode='psd') -> tuple[Array, Array, Array]:
    """Spectrogram from consecutive Fourier transforms; returns (f, t, Sxx)."""

def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,
        detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1) -> tuple[Array, Array, Array]:
    """Short-time Fourier transform; returns (f, t, Zxx)."""

def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
         input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2) -> tuple[Array, Array]:
    """Inverse short-time Fourier transform; returns (t, x)."""

def lombscargle(x, y, freqs, precenter=False, normalize=False) -> Array:
    """Lomb-Scargle periodogram for unevenly sampled data."""

def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False) -> Array:
    """Remove a least-squares linear fit (type='linear') or the mean (type='constant') along axis."""

def find_peaks(x, height=None, threshold=None, distance=None, prominence=None,
              width=None, wlen=None, rel_height=0.5, plateau_size=None) -> tuple[Array, dict]:
    """Find local maxima in a 1D array; returns (peak_indices, properties)."""

def peak_prominences(x, peaks, wlen=None) -> tuple[Array, Array, Array]:
    """Prominence of each peak; returns (prominences, left_bases, right_bases)."""

def peak_widths(x, peaks, rel_height=0.5, prominence_data=None, wlen=None) -> tuple[Array, Array, Array, Array]:
    """Width of each peak; returns (widths, width_heights, left_ips, right_ips)."""

Other Submodules

# Fast Fourier Transform (jax.scipy.fft)
import jax.scipy.fft as jfft
# Discrete cosine transforms (dct, dctn, idct, idctn); general-purpose FFTs
# live in jax.numpy.fft.

# N-dimensional image processing (jax.scipy.ndimage)
import jax.scipy.ndimage as jnd
# Coordinate mapping / resampling (map_coordinates); most of scipy.ndimage's
# filtering, morphology and measurement routines are not provided.

# Sparse linear solvers (jax.scipy.sparse)
# Aliased as `jsparse` so it does not shadow `jss`, the conventional alias
# for jax.scipy.special used elsewhere on this page.
import jax.scipy.sparse as jsparse
# Iterative solvers in jax.scipy.sparse.linalg (cg, gmres, bicgstab) that
# accept matrices or linear-operator callables; sparse matrix formats live
# in jax.experimental.sparse.

# Interpolation (jax.scipy.interpolate)
import jax.scipy.interpolate as jinterp
# RegularGridInterpolator for data on a rectilinear grid.

# Clustering (jax.scipy.cluster)
import jax.scipy.cluster as jsc
# Vector quantization (jax.scipy.cluster.vq.vq); no hierarchical or k-means
# fitting routines.

# Numerical integration (jax.scipy.integrate)
# Distinct alias: reusing `jsi` here would silently rebind the interpolate
# module imported above.
import jax.scipy.integrate as jinteg
# Sample-based integration (trapezoid); for ODE solving see
# jax.experimental.ode.

Usage Examples

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.linalg as jla
import jax.scipy.special as jss
import jax.scipy.stats as jst

# Linear algebra example
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Solve the linear system A @ x_sol = b
x_sol = jla.solve(A, b)

# Eigenvalues (ascending) and eigenvectors (columns) of the symmetric matrix A
eigenvals, eigenvecs = jla.eigh(A)

# Cholesky decomposition. jax.scipy.linalg.cholesky follows SciPy and
# returns the UPPER factor by default; request the lower factor explicitly.
U = jla.cholesky(A)              # A = U.T @ U
L = jla.cholesky(A, lower=True)  # A = L @ L.T

# Special functions
x = jnp.linspace(-3, 3, 100)
erf_vals = jss.erf(x)
# Gamma has poles at 0, -1, -2, ...; x + 1 would start exactly at -2, so
# evaluate it on a strictly positive grid instead.
gamma_vals = jss.gamma(jnp.linspace(0.5, 5.0, 100))

# Statistical distributions
data = jnp.array([1.2, 2.3, 1.8, 3.1, 2.7])
log_likelihood = jst.norm.logpdf(data, loc=2.0, scale=1.0).sum()

# Probability density functions
x_vals = jnp.linspace(0, 5, 100)
pdf_vals = jst.gamma.pdf(x_vals, a=2.0, scale=1.0)

# Use in optimization with JAX transformations
@jax.jit
def neg_log_likelihood(params, data):
    """Negative Gaussian log-likelihood of `data` for params = (mu, sigma)."""
    mu, sigma = params
    return -jst.norm.logpdf(data, mu, sigma).sum()

# Compute gradient for maximum likelihood estimation. `params` is a pytree
# (here a list), so the gradient comes back as a list [d/dmu, d/dsigma].
grad_fn = jax.grad(neg_log_likelihood)
gradients = grad_fn([2.0, 1.0], data)

Install with Tessl CLI

npx tessl i tessl/pypi-jax

docs

core-transformations.md

device-memory.md

experimental.md

index.md

low-level-ops.md

neural-networks.md

numpy-compatibility.md

random-numbers.md

scipy-compatibility.md

tree-operations.md

tile.json