tessl/pypi-jax

Differentiate, compile, and transform Numpy code.

Random Number Generation

JAX takes a functional approach to pseudo-random number generation: every sampling function receives an explicit PRNG key instead of reading hidden state. This design enables reproducibility, parallelization, and vectorization while avoiding the global random state typical of other libraries.

Core Imports

import jax.random as jr
from jax.random import key, split, normal, uniform

Key Concepts

JAX random numbers require explicit key management:

  • Keys are created from integer seeds
  • Keys must be split to generate independent random sequences
  • Each random function consumes a key and returns deterministic output: the same key always yields the same values, so a key should not be reused for separate draws
  • No global random state - all randomness is explicit
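The points above can be illustrated with a minimal sketch. The variable names (`root`, `k_a`, `k_b`) are illustrative only and not part of the API.

```python
import jax.numpy as jnp
import jax.random as jr

# Create a typed key from an integer seed.
root = jr.key(0)

# Split into independent subkeys; use each subkey for at most one draw.
k_a, k_b = jr.split(root)

# The same key always yields the same values ...
x1 = jr.normal(k_a, (3,))
x2 = jr.normal(k_a, (3,))

# ... while a different subkey yields a different stream.
y = jr.normal(k_b, (3,))

same_key_matches = bool(jnp.array_equal(x1, x2))
different_keys_differ = not bool(jnp.array_equal(x1, y))
```

Because output depends only on the key, reusing a key by accident silently produces correlated (here identical) samples, which is why splitting before each draw is the usual habit.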

Capabilities

Key Management

Generate, split, and manipulate PRNG keys for deterministic random number generation.
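As a rough sketch of how the key utilities in this section fit together, the snippet below derives per-step keys with `fold_in` and round-trips a key through its raw data. The names `base`, `step_keys`, and `rebuilt` are illustrative, and the round trip assumes the default PRNG implementation.

```python
import jax.numpy as jnp
import jax.random as jr

base = jr.key(0)

# fold_in derives a key from (key, integer) deterministically, which is handy
# for per-step or per-example keys without carrying split state around.
step_keys = [jr.fold_in(base, step) for step in range(3)]
draws = jnp.stack([jr.uniform(k) for k in step_keys])

# The same (key, data) pair gives the same derived key, hence the same draw.
repeat_draw = jr.uniform(jr.fold_in(base, 1))

# Round-trip between a typed key and its raw unsigned-integer data.
raw = jr.key_data(base)
rebuilt = jr.wrap_key_data(raw)
roundtrip_ok = bool(jnp.array_equal(jr.key_data(rebuilt), raw))
```

`fold_in` is often convenient when an integer such as a step counter or example index is already available, while `split` suits cases where the number of subkeys is known up front.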

def key(seed: int, impl=None) -> Array:
    """
    Create a typed PRNG key from integer seed.
    
    Args:
        seed: Integer seed value
        impl: PRNG implementation to use
        
    Returns:
        PRNG key array
    """

def PRNGKey(seed: int) -> Array:
    """
    Create legacy PRNG key (uint32 format).
    
    Args:
        seed: Integer seed value
        
    Returns:
        Legacy format PRNG key
    """

def split(key: Array, num: int = 2) -> Array:
    """
    Split PRNG key into multiple independent keys.
    
    Args:
        key: PRNG key to split
        num: Number of keys to generate (default: 2)
        
    Returns:
        Array of shape (num,) + key.shape containing new keys
    """

def fold_in(key: Array, data: int) -> Array:
    """
    Fold integer data into PRNG key.
    
    Args:
        key: PRNG key
        data: Integer to fold into key
        
    Returns:
        New PRNG key with data folded in
    """

def clone(key: Array) -> Array:
    """
    Clone PRNG key for reuse.
    
    Args:
        key: PRNG key to clone
        
    Returns:
        Cloned PRNG key
    """

def key_data(keys: Array) -> Array:
    """
    Extract raw key data from PRNG keys.
    
    Args:
        keys: PRNG key array
        
    Returns:
        Raw key data
    """

def wrap_key_data(key_data: Array, *, impl=None) -> Array:
    """
    Wrap raw key data as PRNG keys.
    
    Args:
        key_data: Raw key data
        impl: PRNG implementation
        
    Returns:
        PRNG key array
    """

def key_impl(key: Array) -> str:
    """
    Get PRNG implementation name for key.
    
    Args:
        key: PRNG key
        
    Returns:
        Implementation name string
    """

Continuous Distributions

Sample from continuous probability distributions.
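Several samplers in this section draw from a standard form of the distribution (for example N(0, 1), Gamma(a, 1), Exponential(1)). A minimal sketch of obtaining other parameterizations by shifting and scaling the draws follows; the parameter names `mu`, `sigma`, `theta`, and `rate` are illustrative and not library arguments.

```python
import jax.numpy as jnp
import jax.random as jr

k_norm, k_gamma, k_exp = jr.split(jr.key(0), 3)
n = 20_000

# N(mu, sigma^2): scale and shift standard normal draws.
mu, sigma = 3.0, 0.5
normal_samples = mu + sigma * jr.normal(k_norm, (n,))

# Gamma(shape=a, scale=theta): multiply Gamma(a, 1) draws by theta.
a, theta = 2.0, 1.5
gamma_samples = theta * jr.gamma(k_gamma, a, (n,))

# Exponential(rate): divide Exponential(1) draws by the rate.
rate = 4.0
exp_samples = jr.exponential(k_exp, (n,)) / rate
```

The sample means should land near `mu`, `a * theta`, and `1 / rate` respectively, which gives a quick sanity check on the transformation.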

def uniform(
    key: Array, 
    shape=(), 
    dtype=float, 
    minval=0.0, 
    maxval=1.0
) -> Array:
    """
    Sample from uniform distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        minval: Minimum value (inclusive)
        maxval: Maximum value (exclusive)
        
    Returns:
        Random samples from uniform distribution
    """

def normal(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from standard normal (Gaussian) distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from N(0, 1)
    """

def multivariate_normal(
    key: Array, 
    mean: Array, 
    cov: Array, 
    shape=(), 
    dtype=float, 
    method='cholesky'
) -> Array:
    """
    Sample from multivariate normal distribution.
    
    Args:
        key: PRNG key
        mean: Mean vector
        cov: Covariance matrix
        shape: Batch shape
        dtype: Output data type
        method: Decomposition method ('cholesky', 'eigh', 'svd')
        
    Returns:
        Random samples from multivariate normal
    """

def truncated_normal(
    key: Array, 
    lower: float, 
    upper: float, 
    shape=(), 
    dtype=float
) -> Array:
    """
    Sample from truncated normal distribution.
    
    Args:
        key: PRNG key
        lower: Lower truncation bound
        upper: Upper truncation bound
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from truncated normal
    """

def beta(key: Array, a: Array, b: Array, shape=(), dtype=float) -> Array:
    """
    Sample from beta distribution.
    
    Args:
        key: PRNG key
        a: Alpha parameter (concentration)
        b: Beta parameter (concentration)
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Beta(a, b)
    """

def gamma(key: Array, a: Array, shape=(), dtype=float) -> Array:
    """
    Sample from gamma distribution.
    
    Args:
        key: PRNG key
        a: Shape parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Gamma(a, 1)
    """

def exponential(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from exponential distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Exponential(1)
    """

def laplace(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Laplace distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Laplace(0, 1)
    """

def logistic(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from logistic distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Logistic(0, 1)
    """

def lognormal(key: Array, sigma=1.0, shape=(), dtype=float) -> Array:
    """
    Sample from log-normal distribution.
    
    Args:
        key: PRNG key
        sigma: Standard deviation of underlying normal
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from log-normal distribution
    """

def pareto(key: Array, b: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Pareto distribution.
    
    Args:
        key: PRNG key
        b: Shape parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Pareto(b, 1)
    """

def cauchy(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Cauchy distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Cauchy(0, 1)
    """

def double_sided_maxwell(
    key: Array, 
    loc: Array, 
    scale: Array, 
    shape=(), 
    dtype=float
) -> Array:
    """
    Sample from double-sided Maxwell distribution.
    
    Args:
        key: PRNG key
        loc: Location parameter
        scale: Scale parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from double-sided Maxwell
    """

def maxwell(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Maxwell distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Maxwell distribution
    """

def rayleigh(key: Array, scale: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Rayleigh distribution.
    
    Args:
        key: PRNG key
        scale: Scale parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Rayleigh(scale)
    """

def wald(key: Array, mean: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Wald (Inverse Gaussian) distribution.
    
    Args:
        key: PRNG key
        mean: Mean parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Wald distribution
    """

def weibull_min(
    key: Array, 
    scale: Array, 
    concentration: Array, 
    shape=(), 
    dtype=float
) -> Array:
    """
    Sample from Weibull minimum distribution.
    
    Args:
        key: PRNG key
        scale: Scale parameter (note: comes before concentration)
        concentration: Shape parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Weibull minimum
    """

def gumbel(key: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Gumbel distribution.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Gumbel(0, 1)
    """

def chisquare(key: Array, df: Array, shape=(), dtype=float) -> Array:
    """
    Sample from chi-square distribution.
    
    Args:
        key: PRNG key
        df: Degrees of freedom
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from chi-square(df)
    """

def dirichlet(
    key: Array, 
    alpha: Array, 
    shape=(), 
    dtype=float
) -> Array:
    """
    Sample from Dirichlet distribution.
    
    Args:
        key: PRNG key
        alpha: Concentration parameters
        shape: Batch shape
        dtype: Output data type
        
    Returns:
        Random samples from Dirichlet(alpha)
    """

def f(key: Array, dfnum: Array, dfden: Array, shape=(), dtype=float) -> Array:
    """
    Sample from F-distribution.
    
    Args:
        key: PRNG key
        dfnum: Numerator degrees of freedom
        dfden: Denominator degrees of freedom  
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from F-distribution
    """

def t(key: Array, df: Array, shape=(), dtype=float) -> Array:
    """
    Sample from Student's t-distribution.
    
    Args:
        key: PRNG key
        df: Degrees of freedom
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from t-distribution
    """

def triangular(
    key: Array, 
    left: Array, 
    mode: Array, 
    right: Array, 
    shape=(), 
    dtype=float
) -> Array:
    """
    Sample from triangular distribution.
    
    Args:
        key: PRNG key
        left: Left boundary
        mode: Mode (peak) value
        right: Right boundary
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from triangular distribution
    """

def generalized_normal(
    key: Array, 
    p: Array, 
    shape=(), 
    dtype=float
) -> Array:
    """
    Sample from generalized normal distribution.
    
    Args:
        key: PRNG key
        p: Shape parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from generalized normal
    """

def loggamma(key: Array, a: Array, shape=(), dtype=float) -> Array:
    """
    Sample log-gamma random variables.
    
    Args:
        key: PRNG key
        a: Shape parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from log-gamma distribution
    """

Discrete Distributions

Sample from discrete probability distributions.
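The following sketch exercises a few of the discrete samplers in this section. Keyword arguments are used throughout so the calls do not depend on positional order; the probabilities and sizes are arbitrary illustration values.

```python
import jax.numpy as jnp
import jax.random as jr

k_int, k_coin, k_cat, k_choice = jr.split(jr.key(0), 4)

# Integers in [0, 10); keywords make the bounds and shape explicit.
ints = jr.randint(k_int, shape=(1000,), minval=0, maxval=10)

# Bernoulli draws come back as a boolean array.
flips = jr.bernoulli(k_coin, p=0.3, shape=(1000,))

# categorical takes (possibly unnormalized) log-probabilities.
probs = jnp.array([0.1, 0.2, 0.7])
cats = jr.categorical(k_cat, jnp.log(probs), shape=(5000,))
freq = jnp.bincount(cats, length=3) / cats.size

# choice without replacement returns distinct elements.
picked = jr.choice(k_choice, jnp.arange(50), shape=(10,), replace=False)
```

Empirical frequencies in `freq` should approximate `probs`, and `picked` should contain no repeated values.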

def bernoulli(key: Array, p=0.5, shape=None) -> Array:
    """
    Sample from Bernoulli distribution.
    
    Args:
        key: PRNG key
        p: Success probability
        shape: Output shape (defaults to the shape of p)
        
    Returns:
        Boolean array of random samples from Bernoulli(p)
    """

def binomial(key: Array, n: Array, p: Array, shape=(), dtype=float) -> Array:
    """
    Sample from binomial distribution.
    
    Args:
        key: PRNG key
        n: Number of trials
        p: Success probability per trial
        shape: Output shape
        dtype: Output data type (a floating-point type)
        
    Returns:
        Random samples from Binomial(n, p)
    """

def categorical(
    key: Array, 
    logits: Array, 
    axis=-1, 
    shape=None
) -> Array:
    """
    Sample from categorical distribution.
    
    Args:
        key: PRNG key
        logits: Unnormalized log-probabilities
        axis: Axis over which to normalize
        shape: Output shape
        
    Returns:
        Random categorical indices
    """

def choice(
    key: Array, 
    a: int | Array, 
    shape=(), 
    replace=True, 
    p=None, 
    axis=0
) -> Array:
    """
    Random choice from array elements.
    
    Args:
        key: PRNG key
        a: Array to sample from or integer (range)
        shape: Output shape
        replace: Whether to sample with replacement
        p: Probabilities for each element
        axis: Axis to sample along
        
    Returns:
        Random samples from input array
    """

def geometric(key: Array, p: Array, shape=(), dtype=int) -> Array:
    """
    Sample from geometric distribution.
    
    Args:
        key: PRNG key
        p: Success probability
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Geometric(p)
    """

def poisson(key: Array, lam: Array, shape=(), dtype=int) -> Array:
    """
    Sample from Poisson distribution.
    
    Args:
        key: PRNG key
        lam: Rate parameter
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from Poisson(lam)
    """

def multinomial(
    key: Array, 
    n: Array, 
    p: Array, 
    *,
    shape=None, 
    dtype=float
) -> Array:
    """
    Sample from multinomial distribution.
    
    Args:
        key: PRNG key
        n: Number of trials
        p: Probability values for each category (along the last axis)
        shape: Output shape (keyword-only)
        dtype: Output data type (a floating-point type)
        
    Returns:
        Random samples from Multinomial(n, p)
    """

def randint(
    key: Array, 
    shape, 
    minval: int, 
    maxval: int, 
    dtype=int
) -> Array:
    """
    Sample random integers from [minval, maxval).
    
    Args:
        key: PRNG key
        shape: Output shape (required; comes before the bounds)
        minval: Minimum value (inclusive)
        maxval: Maximum value (exclusive)
        dtype: Output data type
        
    Returns:
        Random integers in specified range
    """

def rademacher(key: Array, shape=(), dtype=int) -> Array:
    """
    Sample from Rademacher distribution (±1 with equal probability).
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Output data type
        
    Returns:
        Random samples from {-1, +1}
    """

Specialized Sampling

Special sampling functions for geometric shapes and structured sampling.
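A short sketch of the structured samplers in this section, with simple checks on the properties each result is expected to have. The dimensions used here are arbitrary.

```python
import jax.numpy as jnp
import jax.random as jr

k_ball, k_orth, k_perm = jr.split(jr.key(0), 3)

# 100 points in the 3-dimensional Euclidean unit ball: shape (100, 3).
points = jr.ball(k_ball, d=3, shape=(100,))
max_norm = jnp.linalg.norm(points, axis=-1).max()

# A random 4x4 orthogonal matrix Q should satisfy Q^T Q ≈ I.
q = jr.orthogonal(k_orth, n=4)
orth_error = jnp.abs(q.T @ q - jnp.eye(4)).max()

# An integer argument permutes the integers 0..n-1.
perm = jr.permutation(k_perm, 10)
```

Here `max_norm` should not exceed 1 and `orth_error` should be at the level of floating-point round-off.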

def ball(key: Array, d: int, p=2, shape=(), dtype=float) -> Array:
    """
    Sample uniformly from d-dimensional unit ball.
    
    Args:
        key: PRNG key
        d: Dimension of ball
        p: Norm type (default: 2 for Euclidean)
        shape: Batch shape
        dtype: Output data type
        
    Returns:
        Random samples from unit ball
    """

def orthogonal(key: Array, n: int, shape=(), dtype=float) -> Array:
    """
    Sample random orthogonal matrix.
    
    Args:
        key: PRNG key
        n: Matrix dimension
        shape: Batch shape
        dtype: Output data type
        
    Returns:
        Random orthogonal matrix of size (n, n)
    """

def permutation(key: Array, x: int | Array, axis=0, independent=False) -> Array:
    """
    Generate random permutation of array or integers.
    
    Args:
        key: PRNG key
        x: Array to permute or integer (range)
        axis: Axis to permute along
        independent: Whether to permute each batch element independently
        
    Returns:
        Randomly permuted array
    """

def bits(key: Array, shape=(), dtype=None) -> Array:
    """
    Generate uniformly random bits as unsigned integers.
    
    Args:
        key: PRNG key
        shape: Output shape
        dtype: Unsigned integer output data type; its width sets the bits per sample
        
    Returns:
        Random bit patterns
    """

Usage Examples

Common patterns for JAX random number generation:

import jax
import jax.numpy as jnp
import jax.random as jr

# Create a root key and split it into one subkey per independent draw
main_key = jr.key(42)
(key_norm, key_int, key_batch, key_vmap, key_perm,
 key_choice, key_mvn, key_coin, key_dice) = jr.split(main_key, 9)

# Basic sampling (randint takes shape before the bounds)
samples = jr.normal(key_norm, (1000,))
random_ints = jr.randint(key_int, (100,), 0, 10)

# Batch sampling with a single key
batch_samples = jr.normal(key_batch, (32, 784))  # 32 samples of 784 dims

# Different keys for each batch element
keys = jr.split(key_vmap, 32)
independent_samples = jax.vmap(
    lambda k: jr.normal(k, (784,))
)(keys)

# Random choice and permutation
data = jnp.arange(100)
shuffled = jr.permutation(key_perm, data)
selected = jr.choice(key_choice, data, (10,), replace=False)

# Multivariate distributions
mean = jnp.zeros(5)
cov = jnp.eye(5)
mv_samples = jr.multivariate_normal(key_mvn, mean, cov, (1000,))

# Discrete distributions (categorical's third positional argument is axis,
# so shape is passed by keyword)
coin_flips = jr.bernoulli(key_coin, 0.6, (100,))
dice_rolls = jr.categorical(key_dice, jnp.log(jnp.ones(6) / 6), shape=(100,))

# Using in neural network initialization
def init_layer_weights(key, input_dim, output_dim):
    w_key, b_key = jr.split(key)
    # Xavier/Glorot normal initialization
    std = jnp.sqrt(2.0 / (input_dim + output_dim))
    weights = jr.normal(w_key, (input_dim, output_dim)) * std
    biases = jr.normal(b_key, (output_dim,)) * 0.01
    return weights, biases

# Stochastic gradient descent with random batching
def get_random_batch(key, data, batch_size):
    # Sample batch_size distinct row indices
    indices = jr.choice(key, len(data), (batch_size,), replace=False)
    return data[indices]
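One more pattern that often comes up is carrying a key through a loop. The sketch below is a hypothetical helper (`noisy_steps` is not part of JAX) that splits the carried key once per iteration, so each step gets a fresh subkey and the whole run stays reproducible from the seed.

```python
import jax.numpy as jnp
import jax.random as jr

def noisy_steps(key, num_steps):
    """Random walk: carry the key through the loop, split once per step."""
    x = jnp.zeros(())
    history = []
    for _ in range(num_steps):
        # Replace the carried key and obtain a one-time key for this step.
        key, step_key = jr.split(key)
        x = x + jr.normal(step_key)
        history.append(x)
    return jnp.stack(history)

# Two runs from the same seed produce identical trajectories.
walk_a = noisy_steps(jr.key(0), 5)
walk_b = noisy_steps(jr.key(0), 5)
```

Overwriting `key` with the first output of `split` keeps the old key from being used again later in the loop.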

Install with Tessl CLI

npx tessl i tessl/pypi-jax
