Differentiate, compile, and transform NumPy code.
—
JAX uses a functional approach to pseudo-random number generation with explicit key management. This design enables reproducibility, parallelization, and vectorization while avoiding global state typical of other libraries.
import jax.random as jr
from jax.random import key, split, normal, uniform

JAX random numbers require explicit key management:
Generate, split, and manipulate PRNG keys for deterministic random number generation.
def key(seed: int, impl=None) -> Array:
"""
Create a typed PRNG key from integer seed.
Args:
seed: Integer seed value
impl: PRNG implementation to use
Returns:
PRNG key array
"""
def PRNGKey(seed: int) -> Array:
"""
Create legacy PRNG key (uint32 format).
Args:
seed: Integer seed value
Returns:
Legacy format PRNG key
"""
def split(key: Array, num: int = 2) -> Array:
"""
Split PRNG key into multiple independent keys.
Args:
key: PRNG key to split
num: Number of keys to generate (default: 2)
Returns:
Array of shape (num,) + key.shape containing new keys
"""
def fold_in(key: Array, data: int) -> Array:
"""
Fold integer data into PRNG key.
Args:
key: PRNG key
data: Integer to fold into key
Returns:
New PRNG key with data folded in
"""
def clone(key: Array) -> Array:
"""
Clone PRNG key for reuse.
Args:
key: PRNG key to clone
Returns:
Cloned PRNG key
"""
def key_data(keys: Array) -> Array:
"""
Extract raw key data from PRNG keys.
Args:
keys: PRNG key array
Returns:
Raw key data
"""
def wrap_key_data(key_data: Array, *, impl=None) -> Array:
"""
Wrap raw key data as PRNG keys.
Args:
key_data: Raw key data
impl: PRNG implementation
Returns:
PRNG key array
"""
def key_impl(key: Array) -> str:
"""
Get PRNG implementation name for key.
Args:
key: PRNG key
Returns:
Implementation name string
"""

Sample from continuous probability distributions.
def uniform(
key: Array,
shape=(),
dtype=float,
minval=0.0,
maxval=1.0
) -> Array:
"""
Sample from uniform distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
minval: Minimum value (inclusive)
maxval: Maximum value (exclusive)
Returns:
Random samples from uniform distribution
"""
def normal(key: Array, shape=(), dtype=float) -> Array:
"""
Sample from standard normal (Gaussian) distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from N(0, 1)
"""
def multivariate_normal(
key: Array,
mean: Array,
cov: Array,
shape=(),
dtype=float,
method='cholesky'
) -> Array:
"""
Sample from multivariate normal distribution.
Args:
key: PRNG key
mean: Mean vector
cov: Covariance matrix
shape: Batch shape
dtype: Output data type
method: Decomposition method ('cholesky', 'eigh', 'svd')
Returns:
Random samples from multivariate normal
"""
def truncated_normal(
key: Array,
lower: float,
upper: float,
shape=(),
dtype=float
) -> Array:
"""
Sample from truncated normal distribution.
Args:
key: PRNG key
lower: Lower truncation bound
upper: Upper truncation bound
shape: Output shape
dtype: Output data type
Returns:
Random samples from truncated normal
"""
def beta(key: Array, a: Array, b: Array, shape=(), dtype=float) -> Array:
"""
Sample from beta distribution.
Args:
key: PRNG key
a: Alpha parameter (concentration)
b: Beta parameter (concentration)
shape: Output shape
dtype: Output data type
Returns:
Random samples from Beta(a, b)
"""
def gamma(key: Array, a: Array, shape=(), dtype=float) -> Array:
"""
Sample from gamma distribution.
Args:
key: PRNG key
a: Shape parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from Gamma(a, 1)
"""
def exponential(key: Array, shape=(), dtype=float) -> Array:
"""
Sample from exponential distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from Exponential(1)
"""
def laplace(key: Array, shape=(), dtype=float) -> Array:
"""
Sample from Laplace distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from Laplace(0, 1)
"""
def logistic(key: Array, shape=(), dtype=float) -> Array:
"""
Sample from logistic distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from Logistic(0, 1)
"""
def lognormal(key: Array, sigma=1.0, shape=(), dtype=float) -> Array:
"""
Sample from log-normal distribution.
Args:
key: PRNG key
sigma: Standard deviation of underlying normal
shape: Output shape
dtype: Output data type
Returns:
Random samples from log-normal distribution
"""
def pareto(key: Array, b: Array, shape=(), dtype=float) -> Array:
"""
Sample from Pareto distribution.
Args:
key: PRNG key
b: Shape parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from Pareto(b, 1)
"""
def cauchy(key: Array, shape=(), dtype=float) -> Array:
"""
Sample from Cauchy distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from Cauchy(0, 1)
"""
def double_sided_maxwell(
key: Array,
loc: Array,
scale: Array,
shape=(),
dtype=float
) -> Array:
"""
Sample from double-sided Maxwell distribution.
Args:
key: PRNG key
loc: Location parameter
scale: Scale parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from double-sided Maxwell
"""
def maxwell(key: Array, shape=(), dtype=float) -> Array:
"""
Sample from Maxwell distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from Maxwell distribution
"""
def rayleigh(key: Array, scale=1.0, shape=(), dtype=float) -> Array:
"""
Sample from Rayleigh distribution.
Args:
key: PRNG key
scale: Scale parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from Rayleigh(scale)
"""
def wald(key: Array, mean: Array, shape=(), dtype=float) -> Array:
"""
Sample from Wald (Inverse Gaussian) distribution.
Args:
key: PRNG key
mean: Mean parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from Wald distribution
"""
def weibull_min(
key: Array,
scale: Array,
concentration: Array,
shape=(),
dtype=float
) -> Array:
"""
Sample from Weibull minimum distribution.
Args:
key: PRNG key
scale: Scale parameter
concentration: Concentration (shape) parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from Weibull minimum
"""
def gumbel(key: Array, shape=(), dtype=float) -> Array:
"""
Sample from Gumbel distribution.
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from Gumbel(0, 1)
"""
def chisquare(key: Array, df: Array, shape=(), dtype=float) -> Array:
"""
Sample from chi-square distribution.
Args:
key: PRNG key
df: Degrees of freedom
shape: Output shape
dtype: Output data type
Returns:
Random samples from chi-square(df)
"""
def dirichlet(
key: Array,
alpha: Array,
shape=(),
dtype=float
) -> Array:
"""
Sample from Dirichlet distribution.
Args:
key: PRNG key
alpha: Concentration parameters
shape: Batch shape
dtype: Output data type
Returns:
Random samples from Dirichlet(alpha)
"""
def f(key: Array, dfnum: Array, dfden: Array, shape=(), dtype=float) -> Array:
"""
Sample from F-distribution.
Args:
key: PRNG key
dfnum: Numerator degrees of freedom
dfden: Denominator degrees of freedom
shape: Output shape
dtype: Output data type
Returns:
Random samples from F-distribution
"""
def t(key: Array, df: Array, shape=(), dtype=float) -> Array:
"""
Sample from Student's t-distribution.
Args:
key: PRNG key
df: Degrees of freedom
shape: Output shape
dtype: Output data type
Returns:
Random samples from t-distribution
"""
def triangular(
key: Array,
left: Array,
mode: Array,
right: Array,
shape=(),
dtype=float
) -> Array:
"""
Sample from triangular distribution.
Args:
key: PRNG key
left: Left boundary
mode: Mode (peak) value
right: Right boundary
shape: Output shape
dtype: Output data type
Returns:
Random samples from triangular distribution
"""
def generalized_normal(
key: Array,
p: Array,
shape=(),
dtype=float
) -> Array:
"""
Sample from generalized normal distribution.
Args:
key: PRNG key
p: Shape parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from generalized normal
"""
def loggamma(key: Array, a: Array, shape=(), dtype=float) -> Array:
"""
Sample log-gamma random variables.
Args:
key: PRNG key
a: Shape parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from log-gamma distribution
"""

Sample from discrete probability distributions.
def bernoulli(key: Array, p=0.5, shape=None) -> Array:
"""
Sample from Bernoulli distribution.
Args:
key: PRNG key
p: Success probability
shape: Output shape (defaults to the shape of p)
Returns:
Boolean array of random samples from Bernoulli(p)
"""
def binomial(key: Array, n: Array, p: Array, shape=(), dtype=int) -> Array:
"""
Sample from binomial distribution.
Args:
key: PRNG key
n: Number of trials
p: Success probability per trial
shape: Output shape
dtype: Output data type
Returns:
Random samples from Binomial(n, p)
"""
def categorical(
key: Array,
logits: Array,
axis=-1,
shape=None
) -> Array:
"""
Sample from categorical distribution.
Args:
key: PRNG key
logits: Log-probability array
axis: Axis over which to normalize
shape: Output shape
Returns:
Random categorical indices
"""
def choice(
key: Array,
a: int | Array,
shape=(),
replace=True,
p=None,
axis=0
) -> Array:
"""
Random choice from array elements.
Args:
key: PRNG key
a: Array to sample from or integer (range)
shape: Output shape
replace: Whether to sample with replacement
p: Probabilities for each element
axis: Axis to sample along
Returns:
Random samples from input array
"""
def geometric(key: Array, p: Array, shape=(), dtype=int) -> Array:
"""
Sample from geometric distribution.
Args:
key: PRNG key
p: Success probability
shape: Output shape
dtype: Output data type
Returns:
Random samples from Geometric(p)
"""
def poisson(key: Array, lam: Array, shape=(), dtype=int) -> Array:
"""
Sample from Poisson distribution.
Args:
key: PRNG key
lam: Rate parameter
shape: Output shape
dtype: Output data type
Returns:
Random samples from Poisson(lam)
"""
def multinomial(
key: Array,
n: Array,
pvals: Array,
shape=(),
dtype=int
) -> Array:
"""
Sample from multinomial distribution.
Args:
key: PRNG key
n: Number of trials
pvals: Probability values for each category
shape: Batch shape
dtype: Output data type
Returns:
Random samples from Multinomial(n, pvals)
"""
def randint(
key: Array,
shape,
minval: int,
maxval: int,
dtype=int
) -> Array:
"""
Sample random integers from [minval, maxval).
Args:
key: PRNG key
shape: Output shape (required; comes before the bounds)
minval: Minimum value (inclusive)
maxval: Maximum value (exclusive)
dtype: Output data type
Returns:
Random integers in specified range
"""
def rademacher(key: Array, shape=(), dtype=int) -> Array:
"""
Sample from Rademacher distribution (±1 with equal probability).
Args:
key: PRNG key
shape: Output shape
dtype: Output data type
Returns:
Random samples from {-1, +1}
"""

Special sampling functions for geometric shapes and structured sampling.
def ball(key: Array, d: int, p=2, shape=(), dtype=float) -> Array:
"""
Sample uniformly from d-dimensional unit ball.
Args:
key: PRNG key
d: Dimension of ball
p: Norm type (default: 2 for Euclidean)
shape: Batch shape
dtype: Output data type
Returns:
Random samples from unit ball
"""
def orthogonal(key: Array, n: int, shape=(), dtype=float) -> Array:
"""
Sample random orthogonal matrix.
Args:
key: PRNG key
n: Matrix dimension
shape: Batch shape
dtype: Output data type
Returns:
Random orthogonal matrix of size (n, n)
"""
def permutation(key: Array, x: int | Array, axis=0, independent=False) -> Array:
"""
Generate random permutation of array or integers.
Args:
key: PRNG key
x: Array to permute or integer (range)
axis: Axis to permute along
independent: Whether to permute each batch element independently
Returns:
Randomly permuted array
"""
def bits(key: Array, shape=(), dtype=None) -> Array:
"""
Generate random bits.
Args:
key: PRNG key
shape: Output shape
dtype: Unsigned integer output dtype; its width sets the number of bits per sample (default: uint32, or uint64 when 64-bit mode is enabled)
Returns:
Random bit patterns
"""

Common patterns for JAX random number generation:
import jax
import jax.numpy as jnp
import jax.random as jr
# Create and split keys (carry a fresh main_key forward so it can be split again later)
main_key = jr.key(42)
main_key, key1, key2, key3 = jr.split(main_key, 4)
# Basic sampling
samples = jr.normal(key1, (1000,))
random_ints = jr.randint(key2, (100,), 0, 10)
# Batch sampling with same key
batch_samples = jr.normal(key3, (32, 784)) # 32 samples of 784 dims
# Different keys for each batch element
keys = jr.split(main_key, 32)
independent_samples = jax.vmap(
lambda k: jr.normal(k, (784,))
)(keys)
# Random choice and permutation
data = jnp.arange(100)
shuffled = jr.permutation(key1, data)
selected = jr.choice(key2, data, (10,), replace=False)
# Multivariate distributions
mean = jnp.zeros(5)
cov = jnp.eye(5)
mv_samples = jr.multivariate_normal(key1, mean, cov, (1000,))
# Discrete distributions
coin_flips = jr.bernoulli(key1, 0.6, (100,))
dice_rolls = jr.categorical(key2, jnp.log(jnp.ones(6) / 6), shape=(100,))
# Using in neural network initialization
def init_layer_weights(key, input_dim, output_dim):
    w_key, b_key = jr.split(key)
    # Xavier/Glorot initialization
    std = jnp.sqrt(2.0 / (input_dim + output_dim))
    weights = jr.normal(w_key, (input_dim, output_dim)) * std
    biases = jr.normal(b_key, (output_dim,)) * 0.01
    return weights, biases
# Stochastic gradient descent with random batching
def get_random_batch(key, data, batch_size):
    indices = jr.choice(key, len(data), (batch_size,), replace=False)
    return data[indices]

Install with Tessl CLI
npx tessl i tessl/pypi-jax