Differentiate, compile, and transform NumPy code.
—
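JAX provides SciPy-compatible functions through `jax.scipy` for scientific computing, including linear algebra, signal processing, special functions, statistics, and sparse linear solvers. These functions are differentiable and can be JIT-compiled. The signatures below follow SciPy's conventions; check the JAX API reference for which functions and arguments a given JAX release actually supports.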
import jax.scipy as jsp
import jax.scipy.linalg as jla
import jax.scipy.special as jss
from jax import Array
import jax.scipy.stats as jst
### Linear Algebra (jax.scipy.linalg)
# Advanced linear algebra operations for matrix computations and decompositions.
# Matrix decompositions
def cholesky(a, lower=True) -> Array:
    """
    Cholesky decomposition of a Hermitian positive-definite matrix.

    Args:
        a: Hermitian positive-definite matrix to decompose
        lower: Whether to return the lower-triangular factor

    Returns:
        If lower is True, the factor L with a = L @ L.conj().T;
        otherwise the factor U with a = U.conj().T @ U.

    Note:
        The real ``jax.scipy.linalg.cholesky`` follows SciPy and defaults to
        ``lower=False`` (upper factor), unlike ``jax.numpy.linalg.cholesky``,
        which returns the lower factor. Always pass ``lower`` explicitly.
    """
def qr(a, mode='reduced') -> tuple[Array, Array]:
    """
    QR decomposition of a matrix.

    Args:
        a: Matrix to decompose
        mode: Decomposition mode. 'reduced' and 'complete' are the
            ``jax.numpy.linalg.qr`` spellings; the real ``jax.scipy.linalg.qr``
            uses SciPy's modes 'full' (its default), 'economic' and 'r', and
            raises an error for 'reduced'.

    Returns:
        Tuple (Q, R) where Q is orthogonal/unitary (orthonormal columns in
        economic mode) and R is upper triangular; SciPy's 'r' mode returns
        only R.
    """
def svd(a, full_matrices=True, compute_uv=True, hermitian=False) -> tuple[Array, Array, Array]:
    """
    Singular Value Decomposition.

    Args:
        a: Matrix to decompose
        full_matrices: Whether U and Vh are square (True) or reduced (False)
        compute_uv: Whether to compute U and Vh in addition to s
        hermitian: Whether the matrix is Hermitian. This argument belongs to
            ``jax.numpy.linalg.svd``; the real ``jax.scipy.linalg.svd``
            instead accepts ``overwrite_a``, ``check_finite`` and
            ``lapack_driver``.

    Returns:
        Tuple (U, s, Vh) where a = U @ diag(s) @ Vh; when compute_uv is False
        only the singular values s are returned.
    """
def eig(a, b=None, left=False, right=True, overwrite_a=False, overwrite_b=False,
        check_finite=True, homogeneous_eigvals=False) -> tuple[Array, Array]:
    """
    Eigenvalues and eigenvectors of a general square matrix.

    Args:
        a: Square matrix
        b: Optional matrix for the generalized problem a @ v = w * (b @ v)
        left: Whether to compute left eigenvectors
        right: Whether to compute right eigenvectors
        overwrite_a: Whether a can be overwritten
        overwrite_b: Whether b can be overwritten
        check_finite: Whether to check for finite values
        homogeneous_eigvals: Whether to return eigenvalues in homogeneous form

    Returns:
        Tuple (eigenvalues, eigenvectors) for the default left=False,
        right=True.
    """
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
         overwrite_b=False, turbo=True, eigvals=None, type=1,
         check_finite=True) -> tuple[Array, Array]:
    """
    Eigenvalues and eigenvectors of a Hermitian (or real symmetric) matrix.

    Args:
        a: Hermitian matrix
        b: Optional matrix for the generalized problem
        lower: Whether to read the lower (True) or upper (False) triangle of a
        eigvals_only: Whether to compute eigenvalues only
        overwrite_a: Whether a can be overwritten
        overwrite_b: Whether b can be overwritten
        turbo: Whether to use the divide-and-conquer ("turbo") algorithm
        eigvals: Range of eigenvalue indices to compute
        type: Type of generalized eigenvalue problem
        check_finite: Whether to check for finite values

    Returns:
        Eigenvalues in ascending order together with the eigenvectors as
        columns; when eigvals_only is True only the eigenvalue array is
        returned, despite the tuple annotation.
    """
def eigvals(a, b=None, overwrite_a=False, check_finite=True,
            homogeneous_eigvals=False) -> Array:
    """Eigenvalues of a general square matrix (eig without eigenvectors)."""
def eigvalsh(a, b=None, lower=True, overwrite_a=False, overwrite_b=False,
             turbo=True, eigvals=None, type=1, check_finite=True) -> Array:
    """Eigenvalues of a Hermitian matrix (eigh with eigvals_only=True)."""
# Matrix properties and functions
# NOTE(review): several of these (e.g. slogdet, matrix_rank, trace, norm, cond)
# look like jax.numpy / jax.numpy.linalg APIs rather than jax.scipy.linalg
# ones -- confirm the owning module against the JAX API reference.
def det(a) -> Array:
    """Determinant of the square matrix ``a``."""
def slogdet(a) -> tuple[Array, Array]:
    """Sign and log of the absolute determinant of ``a``, returned as a pair."""
def logdet(a) -> Array:
    """Log determinant of the matrix ``a``."""
def matrix_rank(M, tol=None, hermitian=False) -> Array:
    """Rank of the matrix ``M``; ``tol`` is the rank tolerance (None presumably selects a default)."""
def trace(a, offset=0, axis1=0, axis2=1) -> Array:
    """Sum along the diagonal of ``a`` chosen by ``offset``, ``axis1`` and ``axis2``."""
def norm(a, ord=None, axis=None, keepdims=False) -> Array:
    """Matrix or vector norm of ``a``; ``ord`` picks the norm, ``axis`` the reduced axes."""
def cond(x, p=None) -> Array:
    """Condition number of the matrix ``x`` in the norm selected by ``p``."""
# Matrix solutions
def solve(a, b, assume_a='gen', lower=False, overwrite_a=False,
          overwrite_b=False, debug=None, check_finite=True) -> Array:
    """
    Solve the linear system a @ x = b.

    Args:
        a: Coefficient matrix
        b: Right-hand side vector/matrix
        assume_a: Structure of a: 'gen' (general), 'sym' (symmetric),
            'her' (Hermitian) or 'pos' (positive definite)
        lower: Only meaningful when assume_a is 'sym', 'her' or 'pos': if
            True, only the lower triangle of a is read. It does NOT request a
            triangular solve -- use solve_triangular for triangular systems.
        overwrite_a: Whether a may be overwritten (a SciPy flag; JAX arrays
            are immutable, so it cannot take effect)
        overwrite_b: Whether b may be overwritten (same caveat as overwrite_a)
        debug: Debug information level
        check_finite: Whether to check for finite values

    Returns:
        Solution x such that a @ x = b

    Note:
        In the real ``jax.scipy.linalg.solve`` (as in SciPy), ``assume_a``
        comes after ``check_finite`` rather than third, so pass it by keyword.
    """
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
                     overwrite_b=False, debug=None, check_finite=True) -> Array:
    """
    Solve a @ x = b for a triangular matrix a.

    ``trans`` selects the system as in SciPy: 0 solves a @ x = b, 1 solves
    a.T @ x = b and 2 solves a.conj().T @ x = b. ``lower`` says which
    triangle of a holds the matrix; with ``unit_diagonal`` the diagonal of a
    is taken to be all ones and is not read.
    """
def inv(a, overwrite_a=False, check_finite=True) -> Array:
    """Inverse of the square matrix ``a``."""
def pinv(a, rcond=None, hermitian=False, return_rank=False) -> Array:
    """
    Moore-Penrose pseudoinverse of ``a``.

    ``rcond`` is the relative cutoff below which singular values are treated
    as zero. When ``return_rank`` is True the effective rank is returned
    alongside the pseudoinverse.
    """
def lstsq(a, b, rcond=None, lapack_driver=None) -> tuple[Array, Array, Array, Array]:
    """
    Least-squares solution to the linear system a @ x = b.

    Args:
        a: Coefficient matrix
        b: Dependent variable values
        rcond: Cutoff ratio for small singular values
        lapack_driver: LAPACK driver to use

    Returns:
        Tuple (solution, residuals, rank, singular_values)
    """
# Matrix functions
def expm(A) -> Array:
    """Matrix exponential of the square matrix ``A`` (not the elementwise exponential)."""
def funm(A, func, disp=True) -> Array:
    """Evaluate the matrix function defined by the scalar callable ``func`` at ``A``."""
def sqrtm(A, disp=True, blocksize=64) -> Array:
    """Matrix square root: a matrix X with X @ X = A."""
def logm(A, disp=True) -> Array:
    """Matrix logarithm of ``A`` (the inverse operation of expm)."""
def fractional_matrix_power(A, t) -> Array:
    """Fractional matrix power A^t."""
def matrix_power(A, n) -> Array:
    """Integer matrix power A^n."""
# Schur decomposition
def schur(a, output='real') -> tuple[Array, Array]:
    """
    Schur decomposition of ``a`` as a pair (T, Z) with a = Z @ T @ Z^H.

    ``output`` defaults to 'real' (real Schur form); the complex form can be
    derived from it with rsf2csf.
    """
def rsf2csf(T, Z) -> tuple[Array, Array]:
    """Convert a real Schur form (T, Z) to complex Schur form, returned as a new (T, Z) pair."""
# Polar decomposition
def polar(a, side='right') -> tuple[Array, Array]:
    """
    Polar decomposition of a matrix, returned as (U, P).

    ``side`` selects where the Hermitian positive semi-definite factor P
    goes: 'right' gives a = U @ P (the default), 'left' gives a = P @ U.
    """
### Special Functions (jax.scipy.special)
# Special mathematical functions including error functions, gamma functions, and Bessel functions.
# Error functions
def erf(z) -> Array:
    """Error function, erf(z) = 2/sqrt(pi) * integral_0^z exp(-t**2) dt."""
def erfc(x) -> Array:
    """Complementary error function, erfc(x) = 1 - erf(x)."""
def erfinv(y) -> Array:
    """Inverse error function: erf(erfinv(y)) = y for y in (-1, 1)."""
def erfcinv(y) -> Array:
    """Inverse complementary error function: erfc(erfcinv(y)) = y for y in (0, 2)."""
def wofz(z) -> Array:
    """Faddeeva function, w(z) = exp(-z**2) * erfc(-1j * z)."""
# Gamma functions
def gamma(z) -> Array:
    """Gamma function, Gamma(z) = integral_0^inf t**(z-1) * exp(-t) dt."""
def gammaln(x) -> Array:
    """Log of the absolute value of the gamma function, log|Gamma(x)|."""
def digamma(x) -> Array:
    """Digamma (psi) function, the derivative of log Gamma(x)."""
def polygamma(n, x) -> Array:
    """Polygamma function: the n-th derivative of the digamma function at ``x``."""
def gammainc(a, x) -> Array:
    """Regularized lower incomplete gamma function P(a, x) (SciPy convention)."""
def gammaincc(a, x) -> Array:
    """Regularized upper incomplete gamma function Q(a, x) = 1 - P(a, x) (SciPy convention)."""
def gammasgn(x) -> Array:
    """Sign of the gamma function at ``x``."""
def rgamma(x) -> Array:
    """Reciprocal gamma function, 1 / Gamma(x)."""
# Beta functions
def beta(a, b) -> Array:
    """Beta function, B(a, b) = Gamma(a) * Gamma(b) / Gamma(a + b)."""
def betaln(a, b) -> Array:
    """Log of the absolute value of the beta function, log|B(a, b)|."""
def betainc(a, b, x) -> Array:
    """Regularized incomplete beta function I_x(a, b) (SciPy convention)."""
# Bessel functions
# NOTE(review): this list mirrors scipy.special; confirm which names JAX really
# exposes (e.g. it provides i0/i0e/i1/i1e and bessel_jn) before relying on the rest.
def j0(x) -> Array:
    """Bessel function of the first kind of order 0."""
def j1(x) -> Array:
    """Bessel function of the first kind of order 1."""
def jn(n, x) -> Array:
    """Bessel function of the first kind of integer order n."""
def y0(x) -> Array:
    """Bessel function of the second kind of order 0."""
def y1(x) -> Array:
    """Bessel function of the second kind of order 1."""
def yn(n, x) -> Array:
    """Bessel function of the second kind of integer order n."""
def i0(x) -> Array:
    """Modified Bessel function of the first kind of order 0."""
def i0e(x) -> Array:
    """Exponentially scaled modified Bessel function of order 0, exp(-abs(x)) * i0(x)."""
def i1(x) -> Array:
    """Modified Bessel function of the first kind of order 1."""
def i1e(x) -> Array:
    """Exponentially scaled modified Bessel function of order 1, exp(-abs(x)) * i1(x)."""
def iv(v, z) -> Array:
    """Modified Bessel function of the first kind of real order ``v``."""
def k0(x) -> Array:
    """Modified Bessel function of the second kind of order 0."""
def k0e(x) -> Array:
    """Exponentially scaled modified Bessel function of the second kind, exp(x) * k0(x)."""
def k1(x) -> Array:
    """Modified Bessel function of the second kind of order 1."""
def k1e(x) -> Array:
    """Exponentially scaled modified Bessel function of the second kind, exp(x) * k1(x)."""
def kv(v, z) -> Array:
    """Modified Bessel function of the second kind of real order ``v``."""
# Exponential integrals
def expi(x) -> Array:
    """Exponential integral Ei(x)."""
def expn(n, x) -> Array:
    """Generalized exponential integral E_n(x) = integral_1^inf exp(-x*t) / t**n dt."""
# Log-sum-exp and related
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False) -> Array:
    """
    Compute log(sum(b * exp(a))) in a numerically stable way.

    Args:
        a: Input array
        axis: Axis to sum over (None sums over all elements)
        b: Multiplier for each element
        keepdims: Whether to keep reduced dimensions
        return_sign: Whether to return the sign separately

    Returns:
        Log-sum-exp result. When return_sign is True the sign is returned as
        a second value (a (result, sign) pair in SciPy's convention), which
        lets negative weights in ``b`` be represented.
    """
def softmax(x, axis=None) -> Array:
    """Softmax, exp(x) / sum(exp(x)) along ``axis`` (None: over all elements)."""
def log_softmax(x, axis=None) -> Array:
    """Log softmax, x - logsumexp(x) along ``axis``."""
# Combinatorial functions
def factorial(n, exact=False) -> Array:
    """Factorial n!; ``exact`` requests exact integer arithmetic (SciPy convention)."""
def factorial2(n, exact=False) -> Array:
    """Double factorial n!!."""
def factorialk(n, k, exact=False) -> Array:
    """Multifactorial of ``n`` with step ``k`` (k = 2 gives the double factorial)."""
def comb(N, k, exact=False, repetition=False) -> Array:
    """Binomial coefficient "N choose k"; with ``repetition`` counts combinations with replacement."""
def perm(N, k, exact=False) -> Array:
    """Number of k-permutations of N items, N! / (N - k)!."""
# Elliptic integrals
# In SciPy's convention the argument m is the parameter (m = k**2), not the modulus k.
def ellipk(m) -> Array:
    """Complete elliptic integral of the first kind, K(m) = integral_0^{pi/2} (1 - m*sin(t)**2)**-0.5 dt."""
def ellipe(m) -> Array:
    """Complete elliptic integral of the second kind, E(m) = integral_0^{pi/2} (1 - m*sin(t)**2)**0.5 dt."""
def ellipkinc(phi, m) -> Array:
    """Incomplete elliptic integral of the first kind (upper limit ``phi``)."""
def ellipeinc(phi, m) -> Array:
    """Incomplete elliptic integral of the second kind (upper limit ``phi``)."""
# Zeta and related functions
def zeta(x, q=None) -> Array:
    """Hurwitz zeta sum_{n>=0} (n + q)**-x; with the default q=None, the Riemann zeta function."""
def zetac(x) -> Array:
    """Riemann zeta function minus 1, zeta(x) - 1."""
# Hypergeometric functions
def hyp1f1(a, b, x) -> Array:
    """Confluent hypergeometric function 1F1(a; b; x)."""
def hyp2f1(a, b, c, z) -> Array:
    """Gaussian hypergeometric function 2F1(a, b; c; z)."""
def hyperu(a, b, x) -> Array:
    """Confluent hypergeometric function of the second kind U(a, b, x)."""
# Legendre functions
def legendre(n, x) -> Array:
    """Legendre polynomial P_n evaluated at ``x``."""
    # NOTE(review): scipy.special.legendre(n) returns a polynomial object rather
    # than taking ``x`` -- confirm this two-argument signature.
def lpmv(m, v, x) -> Array:
    """Associated Legendre function P_v^m(x) of order ``m`` and degree ``v``."""
# Spherical functions
def sph_harm(m, n, theta, phi) -> Array:
    """
    Spherical harmonic Y_n^m(theta, phi) of order ``m`` and degree ``n``.

    SciPy convention: ``theta`` is the azimuthal angle in [0, 2*pi] and ``phi``
    the polar angle in [0, pi] -- the reverse of the usual physics naming.
    """
# Other special functions
def lambertw(z, k=0, tol=1e-8) -> Array:
    """Lambert W function: the solution w of w * exp(w) = z on branch ``k``."""
def spence(z) -> Array:
    """Spence's function (dilogarithm); SciPy defines it as integral_1^z log(t) / (1 - t) dt."""
def multigammaln(a, d) -> Array:
    """Log of the multivariate gamma function Gamma_d(a) of dimension ``d``."""
def entr(x) -> Array:
    """Elementwise function -x*log(x)."""
def kl_div(x, y) -> Array:
    """Elementwise function x*log(x/y) - x + y."""
def rel_entr(x, y) -> Array:
    """Elementwise function x*log(x/y)."""
def huber(delta, r) -> Array:
    """Huber loss: r**2 / 2 if abs(r) <= delta, else delta * (abs(r) - delta / 2)."""
def pseudo_huber(delta, r) -> Array:
    """Pseudo-Huber loss: delta**2 * (sqrt(1 + (r / delta)**2) - 1)."""
### Statistics (jax.scipy.stats)
# Statistical distributions and functions for probability and hypothesis testing.
# Continuous distributions
class norm:
    """
    Normal distribution with mean ``loc`` and standard deviation ``scale``.

    Static methods follow scipy.stats naming: pdf/logpdf (density),
    cdf/logcdf (cumulative distribution), sf/logsf (survival, 1 - cdf),
    ppf (quantile, inverse of cdf) and isf (inverse of sf). Bodies are
    elided with ``...`` in this listing.

    NOTE: in this flattened listing the class shares its name with the
    ``norm`` stub of the linear-algebra section and rebinds it.
    """
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logcdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def sf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logsf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def ppf(q, loc=0, scale=1) -> Array: ...
    @staticmethod
    def isf(q, loc=0, scale=1) -> Array: ...
class multivariate_normal:
    """
    Multivariate normal distribution with mean vector ``mean`` and covariance ``cov``.

    As listed, ``mean`` defaults to None and ``cov`` to 1; ``allow_singular``
    controls whether a singular covariance is accepted. Bodies are elided.
    """
    @staticmethod
    def pdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...
    @staticmethod
    def logpdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...
class uniform:
    """
    Uniform distribution on [loc, loc + scale] (scipy.stats parametrization).

    Methods follow scipy.stats naming: pdf/logpdf, cdf/logcdf, sf/logsf
    (survival, 1 - cdf) and ppf (quantile). Bodies are elided with ``...``.
    """
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logcdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def sf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logsf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def ppf(q, loc=0, scale=1) -> Array: ...
class beta:
    """
    Beta distribution with shape parameters ``a`` and ``b``, supported on [loc, loc + scale].

    NOTE: in this flattened listing the class shares its name with the
    ``beta`` special-function stub and rebinds it.
    """
    @staticmethod
    def pdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, a, b, loc=0, scale=1) -> Array: ...
class gamma:
    """
    Gamma distribution with shape ``a`` and scale ``scale``, shifted by ``loc``.

    NOTE: in this flattened listing the class shares its name with the
    ``gamma`` special-function stub and rebinds it.
    """
    @staticmethod
    def pdf(x, a, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, a, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, a, loc=0, scale=1) -> Array: ...
class chi2:
    """Chi-square distribution with ``df`` degrees of freedom, shifted by ``loc`` and scaled by ``scale``."""
    @staticmethod
    def pdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, df, loc=0, scale=1) -> Array: ...
class t:
    """Student's t-distribution with ``df`` degrees of freedom, location ``loc`` and scale ``scale``."""
    @staticmethod
    def pdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, df, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, df, loc=0, scale=1) -> Array: ...
class f:
    """F-distribution with numerator degrees of freedom ``dfn`` and denominator degrees of freedom ``dfd``."""
    @staticmethod
    def pdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
class laplace:
    """Laplace distribution with density exp(-abs(x - loc) / scale) / (2 * scale)."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...
class logistic:
    """Logistic distribution with location ``loc`` and scale ``scale``."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...
class pareto:
    """Pareto distribution with shape ``b``, supported on x >= loc + scale (scipy.stats parametrization)."""
    @staticmethod
    def pdf(x, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, b, loc=0, scale=1) -> Array: ...
class expon:
    """Exponential distribution with rate 1 / ``scale``, supported on x >= ``loc``."""
    @staticmethod
    def pdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, loc=0, scale=1) -> Array: ...
class lognorm:
    """
    Log-normal distribution (scipy.stats parametrization).

    ``s`` is the standard deviation of log(x) and ``scale`` equals exp(mu),
    where mu is the mean of log(x).
    """
    @staticmethod
    def pdf(x, s, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, s, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, s, loc=0, scale=1) -> Array: ...
class truncnorm:
    """
    Normal distribution truncated to an interval.

    As in scipy.stats, ``a`` and ``b`` are standardized bounds, i.e.
    (lower_clip - loc) / scale and (upper_clip - loc) / scale, not raw x values.
    """
    @staticmethod
    def pdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
    @staticmethod
    def cdf(x, a, b, loc=0, scale=1) -> Array: ...
# Discrete distributions
class bernoulli:
    """
    Bernoulli distribution with success probability ``p``, shifted by ``loc``.

    Discrete distributions expose pmf/logpmf (probability mass) and cdf.
    """
    @staticmethod
    def pmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, p, loc=0) -> Array: ...
class binom:
    """Binomial distribution: number of successes in ``n`` trials with success probability ``p``."""
    @staticmethod
    def pmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, n, p, loc=0) -> Array: ...
class geom:
    """Geometric distribution: number of trials up to and including the first success (support k >= 1 in scipy.stats)."""
    @staticmethod
    def pmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, p, loc=0) -> Array: ...
class nbinom:
    """Negative binomial distribution: number of failures before the ``n``-th success, with success probability ``p``."""
    @staticmethod
    def pmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, n, p, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, n, p, loc=0) -> Array: ...
class poisson:
    """Poisson distribution with mean ``mu``, shifted by ``loc``."""
    @staticmethod
    def pmf(k, mu, loc=0) -> Array: ...
    @staticmethod
    def logpmf(k, mu, loc=0) -> Array: ...
    @staticmethod
    def cdf(k, mu, loc=0) -> Array: ...
# Statistical functions
def mode(a, axis=0, nan_policy='propagate', keepdims=False) -> Array:
    """
    Most frequent value of ``a`` along ``axis``.

    As in SciPy, the result holds both the modal values and their counts,
    even though this stub annotates a single Array.
    """
def rankdata(a, method='average', axis=None) -> Array:
    """
    Rank the values of ``a`` (ranks start at 1).

    With the default method='average', tied values share the mean of the
    ranks they span; with axis=None the input is flattened first.
    """
# NOTE(review): kendalltau, pearsonr and spearmanr mirror scipy.stats; confirm
# they are actually provided by jax.scipy.stats before relying on them.
def kendalltau(x, y, initial_lexsort=None, nan_policy='propagate', method='auto') -> tuple[Array, Array]:
    """Kendall's tau rank correlation of ``x`` and ``y``, as a (statistic, p-value) pair."""
def pearsonr(x, y) -> tuple[Array, Array]:
    """Pearson correlation coefficient of ``x`` and ``y``, as a (statistic, p-value) pair."""
def spearmanr(a, b=None, axis=0, nan_policy='propagate', alternative='two-sided') -> tuple[Array, Array]:
    """Spearman rank correlation coefficient, as a (statistic, p-value) pair."""
### Signal Processing (jax.scipy.signal)
# Signal processing functions for filtering, convolution, and spectral analysis.
# NOTE(review): the signal signatures below mirror scipy.signal; confirm that each
# function (and each argument) exists in jax.scipy.signal before relying on it.
def convolve(in1, in2, mode='full', method='auto') -> Array:
    """N-dimensional convolution of ``in1`` with ``in2``; ``mode`` is 'full', 'same' or 'valid' as in SciPy."""
def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
    """2D convolution; ``boundary`` selects edge handling and ``fillvalue`` pads when boundary='fill'."""
def correlate(in1, in2, mode='full', method='auto') -> Array:
    """Cross-correlation of two N-dimensional arrays; ``mode`` as in convolve."""
def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
    """2D cross-correlation; ``boundary`` and ``fillvalue`` as in convolve2d."""
def fftconvolve(in1, in2, mode='full', axes=None) -> Array:
    """FFT-based convolution over ``axes`` (None: all axes)."""
def oaconvolve(in1, in2, mode='full', axes=None) -> Array:
    """Overlap-add convolution over ``axes`` (None: all axes)."""
def lfilter(b, a, x, axis=-1, zi=None) -> Array:
    """Linear digital filter with numerator ``b`` and denominator ``a``; ``zi`` gives initial conditions (SciPy then also returns the final state)."""
def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad', irlen=None) -> Array:
    """Zero-phase digital filtering: the filter is applied forward and then backward."""
def sosfilt(sos, x, axis=-1, zi=None) -> Array:
    """Filter using second-order sections ``sos`` (one row of six coefficients per section, as in SciPy)."""
def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None) -> Array:
    """Zero-phase (forward-backward) filtering with second-order sections."""
def hilbert(x, N=None, axis=-1) -> Array:
    """Analytic signal x + 1j * H(x), computed with the Hilbert transform H (SciPy convention); ``N`` is the FFT length."""
def hilbert2(x, N=None) -> Array:
    """2D analytic signal computed with the Hilbert transform."""
def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True) -> Array:
    """Downsample the signal by integer factor ``q`` after an anti-aliasing filter (``ftype`` defaults to 'iir')."""
def resample(x, num, t=None, axis=0, window=None, domain='time') -> Array:
    """Resample the signal to ``num`` samples along ``axis``."""
def resample_poly(x, up, down, axis=0, window='kaiser', padtype='constant', cval=None) -> Array:
    """Resample by the rational factor ``up`` / ``down`` using polyphase filtering."""
def upfirdn(h, x, up=1, down=1, axis=-1, mode='constant', cval=0) -> Array:
    """Upsample by ``up``, apply the FIR filter ``h``, then downsample by ``down``."""
def periodogram(x, fs=1.0, window='boxcar', nfft=None, detrend='constant',
                return_onesided=True, scaling='density', axis=-1) -> tuple[Array, Array]:
    """Periodogram power spectral density estimate, returned as (frequencies, Pxx)."""
def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          detrend='constant', return_onesided=True, scaling='density', axis=-1,
          average='mean') -> tuple[Array, Array]:
    """Welch's method (averaged segment periodograms) for power spectral density, returned as (frequencies, Pxx)."""
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
        detrend='constant', return_onesided=True, scaling='density', axis=-1,
        average='mean') -> tuple[Array, Array]:
    """Cross power spectral density of ``x`` and ``y``, returned as (frequencies, Pxy)."""
def coherence(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
              detrend='constant', axis=-1) -> tuple[Array, Array]:
    """Magnitude-squared coherence between ``x`` and ``y``, returned as (frequencies, Cxy)."""
def spectrogram(x, fs=1.0, window='tukey', nperseg=None, noverlap=None, nfft=None,
                detrend='constant', return_onesided=True, scaling='density', axis=-1,
                mode='psd') -> tuple[Array, Array, Array]:
    """Spectrogram using the short-time Fourier transform, returned as (frequencies, times, Sxx)."""
def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,
         detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1) -> tuple[Array, Array, Array]:
    """Short-time Fourier transform, returned as (frequencies, times, Zxx)."""
def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2) -> tuple[Array, Array]:
    """Inverse short-time Fourier transform, returned as (times, x)."""
def lombscargle(x, y, freqs, precenter=False, normalize=False) -> Array:
    """Lomb-Scargle periodogram of samples ``y`` taken at times ``x``; in SciPy ``freqs`` are angular frequencies."""
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False) -> Array:
    """Remove a trend from ``data`` along ``axis``; the default type='linear' removes a least-squares line (SciPy also accepts 'constant'); ``bp`` gives breakpoints."""
def find_peaks(x, height=None, threshold=None, distance=None, prominence=None,
               width=None, wlen=None, rel_height=0.5, plateau_size=None) -> tuple[Array, dict]:
    """Find peaks in a 1D array, returned as (peak indices, dict of peak properties)."""
def peak_prominences(x, peaks, wlen=None) -> tuple[Array, Array, Array]:
    """Calculate peak prominences, returned as (prominences, left_bases, right_bases)."""
def peak_widths(x, peaks, rel_height=0.5, prominence_data=None, wlen=None) -> tuple[Array, Array, Array, Array]:
    """Calculate peak widths, returned as (widths, width_heights, left_ips, right_ips)."""
# Fast Fourier Transform (jax.scipy.fft)
import jax.scipy.fft as jfft
# Discrete cosine transforms (dct, idct, dctn, idctn); general FFTs live in jax.numpy.fft
# N-dimensional image processing (jax.scipy.ndimage)
import jax.scipy.ndimage as jnd
# Coordinate-based resampling (map_coordinates); no filtering/morphology routines
# Sparse matrix operations (jax.scipy.sparse)
# Aliased as jsparse so it does not rebind jss (jax.scipy.special, imported above).
import jax.scipy.sparse as jsparse
# Iterative solvers in jax.scipy.sparse.linalg (cg, gmres, bicgstab); sparse
# array formats live in jax.experimental.sparse
# Interpolation (jax.scipy.interpolate)
import jax.scipy.interpolate as jsi
# Interpolation on regular grids (RegularGridInterpolator)
# Clustering (jax.scipy.cluster)
import jax.scipy.cluster as jsc
# Vector quantization (jax.scipy.cluster.vq.vq)
# Numerical integration (jax.scipy.integrate); ODE solvers live in jax.experimental.ode
# Aliased as jsint so it does not rebind jsi (jax.scipy.interpolate).
import jax.scipy.integrate as jsint
# Numerical integration (e.g. trapezoid); ODE solvers live in jax.experimental.ode
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.linalg as jla
import jax.scipy.special as jss
import jax.scipy.stats as jst
# Linear algebra example
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([1.0, 2.0])
# Solve the linear system A @ x_solution = b
x_solution = jla.solve(A, b)
# Compute eigenvalues and eigenvectors of the symmetric matrix A
eigenvals, eigenvecs = jla.eigh(A)
# Matrix decomposition. jax.scipy.linalg.cholesky returns the *upper* factor by
# default (like SciPy), so request the lower factor explicitly.
L = jla.cholesky(A, lower=True)  # A = L @ L.T
# Special functions
x = jnp.linspace(-3, 3, 100)
erf_vals = jss.erf(x)
# Gamma has poles at 0, -1, -2, ...; x + 1 starts at exactly -2 here, so some
# entries of gamma_vals (e.g. at x = -3) are not finite.
gamma_vals = jss.gamma(x + 1)
# Statistical distributions
data = jnp.array([1.2, 2.3, 1.8, 3.1, 2.7])
log_likelihood = jst.norm.logpdf(data, loc=2.0, scale=1.0).sum()
# Probability density functions
x_vals = jnp.linspace(0, 5, 100)
pdf_vals = jst.gamma.pdf(x_vals, a=2.0, scale=1.0)
# Use in optimization with JAX transformations
@jax.jit
def neg_log_likelihood(params, data):
    """Negative Gaussian log-likelihood of ``data`` for params = (mu, sigma)."""
    mu, sigma = params
    return -jst.norm.logpdf(data, mu, sigma).sum()
# Compute gradient for maximum likelihood estimation
grad_fn = jax.grad(neg_log_likelihood)
gradients = grad_fn([2.0, 1.0], data)
Install with the Tessl CLI:
npx tessl i tessl/pypi-jax