Stheno: an implementation of Gaussian processes in Python with support for AutoGrad, TensorFlow, PyTorch, and JAX.
Probabilistic computation with Normal distributions and random object hierarchies. This module provides the foundation for uncertainty quantification, sampling, and probabilistic inference in Gaussian process models.
Base classes for random objects providing common arithmetic operations and the foundation for probabilistic modeling.
class Random:
    """Base class for random objects.

    Declares the operator hooks shared by all random objects. This listing
    shows signatures only; the concrete arithmetic presumably lives in
    subclasses such as ``Normal`` (which documents ``__add__``/``__mul__``)
    — confirm against the implementation.
    """

    def __radd__(self, other):
        """Right addition (other + self)."""

    def __rmul__(self, other):
        """Right multiplication (other * self)."""

    def __neg__(self):
        """Negation (-self)."""

    def __sub__(self, other):
        """Subtraction (self - other)."""

    def __rsub__(self, other):
        """Right subtraction (other - self)."""

    def __div__(self, other):
        """Division (self / other).

        Note: ``__div__`` is the Python 2 hook for ``/``; Python 3 only
        consults ``__truediv__`` below.
        """

    def __truediv__(self, other):
        """True division (self / other)."""
class RandomProcess(Random):
    """Base class for random processes.

    Counterpart of ``RandomVector``: both inherit the operator hooks of
    ``Random``.
    """
class RandomVector(Random):
    """Base class for random vectors (``Normal`` derives from this)."""


# Multivariate normal/Gaussian distribution with comprehensive functionality
# for probabilistic computation, sampling, and inference.
class Normal(RandomVector):
    # The four ``__init__`` signatures below are alternative constructor
    # overloads selected by the arguments, e.g. ``Normal(mean, cov)`` versus
    # ``Normal(mean_func, var_func, var_diag=...)``. Read literally as Python,
    # only the last ``def`` would survive, so the library presumably dispatches
    # on argument types (multiple dispatch) — confirm against the implementation.

    def __init__(self, mean, var):
        """
        Initialize Normal distribution with mean and variance.

        Parameters:
        - mean: Mean vector (column vector, i.e. shape (n, 1))
        - var: Covariance matrix (shape (n, n))
        """

    def __init__(self, var):
        """
        Initialize Normal distribution with zero mean.

        Parameters:
        - var: Covariance matrix
        """

    def __init__(self, mean_func, var_func, *, var_diag=None, mean_var=None, mean_var_diag=None):
        """
        Initialize Normal distribution with lazy evaluation functions.

        Each function takes no arguments and is only called when the
        corresponding property is accessed. Supplying ``var_diag`` lets the
        variance diagonal be obtained without calling ``var_func``.

        Parameters:
        - mean_func: Function that returns mean when called
        - var_func: Function that returns variance when called
        - var_diag: Optional function for diagonal variance
        - mean_var: Optional function returning (mean, var) tuple
        - mean_var_diag: Optional function returning (mean, var_diag) tuple
        """

    def __init__(self, var_func, **kw_args):
        """Initialize with zero mean function and variance function.

        ``kw_args`` are presumably the same optional keyword functions as the
        lazy overload above (``var_diag``, ``mean_var``, ``mean_var_diag``) —
        confirm.
        """


# Access distributional properties including moments, dimensions, and data types.
class Normal:
    # Property section of the ``Normal`` API listing. Read literally as Python
    # this re-binds ``Normal`` without a base class; presumably it documents
    # the same class as the constructor section — confirm.

    @property
    def mean(self):
        """column vector: Mean of the distribution."""

    @property
    def var(self):
        """matrix: Covariance matrix of the distribution."""

    @property
    def var_diag(self):
        """vector: Diagonal of the covariance matrix.

        If the distribution was built with a ``var_diag`` function, that
        function is used instead of evaluating the full covariance.
        """

    @property
    def mean_var(self):
        """tuple[column vector, matrix]: Mean and covariance tuple."""

    @property
    def dtype(self):
        """dtype: Data type of the distribution."""

    @property
    def dim(self):
        """int: Dimensionality of the distribution."""

    @property
    def m2(self):
        """matrix: Second moment matrix, E[x x^T] = var + mean @ mean.T."""

    @property
    def mean_is_zero(self):
        """bool: Whether the mean is identically zero."""


# Compute marginal statistics and credible intervals for individual components
# of the multivariate distribution.
class Normal:
    # Marginal-statistics section of the ``Normal`` API listing (presumably the
    # same class as the constructor section — confirm).

    def marginals(self):
        """
        Get marginal means and variances.

        Returns:
        - tuple: (marginal_means, marginal_variances), one entry per dimension
        """

    def marginal_credible_bounds(self):
        """
        Get marginal 95% central credible interval bounds.

        For a Gaussian marginal, the 95% central interval is the mean plus or
        minus roughly 1.96 standard deviations.

        Returns:
        - tuple: (means, lower_bounds, upper_bounds)
        """

    def diagonalise(self):
        """
        Create diagonal version by setting correlations to zero.

        The variances (covariance diagonal) are kept; off-diagonal terms are
        dropped.

        Returns:
        - Normal: Diagonal version of the distribution
        """


# Compute probability densities, entropies, and divergences for model
# evaluation and comparison.
class Normal:
    # Density / entropy / divergence section of the ``Normal`` API listing
    # (presumably the same class as the constructor section — confirm).

    def logpdf(self, x):
        """
        Compute log probability density function.

        Parameters:
        - x: Values to evaluate PDF at; a column vector with ``dim`` entries.
          NaN entries mark missing values, which are handled automatically.
          Presumably several points may be passed as columns of a matrix,
          yielding an array of log-densities — confirm.

        Returns:
        - Log probability density (scalar or array)
        """

    def entropy(self):
        """
        Compute differential entropy of the distribution.

        For a Gaussian this equals 0.5 * log det(2 * pi * e * var).

        Returns:
        - scalar: Entropy value
        """

    def kl(self, other):
        """
        Compute KL divergence with respect to another Normal distribution.

        KL is not symmetric: ``a.kl(b)`` and ``b.kl(a)`` generally differ.

        Parameters:
        - other: Other Normal distribution

        Returns:
        - scalar: KL divergence D_KL(self || other)
        """

    def w2(self, other):
        """
        Compute 2-Wasserstein distance with another Normal distribution.

        Unlike KL, this is a metric and hence symmetric in its arguments.

        Parameters:
        - other: Other Normal distribution

        Returns:
        - scalar: 2-Wasserstein distance
        """


# Generate samples from the Normal distribution with optional noise addition
# and explicit random state management.
class Normal:
    # Sampling section of the ``Normal`` API listing (presumably the same class
    # as the constructor section — confirm). The two ``sample`` signatures are
    # overloads: one takes an explicit random ``state`` first and returns the
    # advanced state, the other uses a global random state. Read literally as
    # Python only the second ``def`` would survive.

    def sample(self, state, num=1, noise=None):
        """
        Sample from distribution with explicit random state.

        Parameters:
        - state: Random state for sampling
        - num: Number of samples to generate
        - noise: Optional additional noise variance (presumably added to the
          covariance diagonal before sampling — confirm)

        Returns:
        - tuple: (new_state, samples); pass ``new_state`` to the next call to
          continue the random stream. ``samples`` has shape (dim, num).
        """

    def sample(self, num=1, noise=None):
        """
        Sample from distribution using global random state.

        Parameters:
        - num: Number of samples to generate
        - noise: Optional additional noise variance

        Returns:
        - tensor: Samples as rank-2 column vectors, shape (dim, num); each
          column is one sample
        """


# Perform arithmetic operations with Normal distributions and scalars while
# maintaining distributional properties.
class Normal:
    # Arithmetic section of the ``Normal`` API listing (presumably the same
    # class as the constructor section — confirm). Reflected forms such as
    # ``scalar * dist`` go through the ``Random.__rmul__`` hook.

    def __add__(self, other):
        """
        Add scalar or another Normal distribution.

        Parameters:
        - other: Scalar or Normal distribution

        Returns:
        - Normal: Resulting distribution
        """

    def __mul__(self, other):
        """
        Multiply by scalar.

        Parameters:
        - other: Scalar multiplier

        Returns:
        - Normal: Scaled distribution
        """

    def lmatmul(self, other):
        """
        Left matrix multiplication (other @ self).

        Parameters:
        - other: Matrix to multiply with; presumably shape (k, dim) — confirm

        Returns:
        - Normal: Transformed distribution
        """

    def rmatmul(self, other):
        """
        Right matrix multiplication (self @ other).

        NOTE(review): for a column-vector distribution of shape (dim, 1), a
        literal ``self @ other`` is not conformable with a (dim, 1) ``other``,
        yet such an argument is used in the examples. The actual semantics
        presumably apply ``other`` transposed — confirm against the
        implementation.

        Parameters:
        - other: Matrix to multiply with

        Returns:
        - Normal: Transformed distribution
        """


# Low-level operations for dtype handling and casting across different backends.
def dtype(dist):
    """
    Get data type of Normal distribution.

    Function-style counterpart of the ``Normal.dtype`` property, presumably
    for backend-generic code — confirm.

    Parameters:
    - dist: Normal distribution

    Returns:
    - Data type
    """
def cast(dtype, dist):
    """
    Cast Normal distribution to specified data type.

    Parameters:
    - dtype: Target data type (note: this parameter name shadows the
      module-level ``dtype`` function inside this function)
    - dist: Normal distribution to cast

    Returns:
    - Normal: Distribution with specified dtype
    """


# Usage examples.
import stheno
import numpy as np

# Build a 2-dimensional Normal from an explicit (2, 1) column-vector mean and
# a symmetric covariance matrix.
mean = np.array([[1.0], [2.0]])
cov = np.array([
    [1.0, 0.5],
    [0.5, 2.0],
])
normal = stheno.Normal(mean, cov)

# Inspect basic properties of the distribution.
flat_mean = normal.mean.flatten()
print(f"Mean: {flat_mean}")
print(f"Variance diagonal: {normal.var_diag}")
print(f"Dimensionality: {normal.dim}")
print(f"Data type: {normal.dtype}")

# Per-dimension (marginal) means and variances.
marginal_stats = normal.marginals()
marginal_means, marginal_vars = marginal_stats
print(f"Marginal means: {marginal_means}")
print(f"Marginal variances: {marginal_vars}")

# Sample from distribution
# Draw 100 samples; each sample is one column, so the shape should be (2, 100).
samples = normal.sample(num=100)
print(f"Sample shape: {samples.shape}")

# Log-density of a single (2, 1) test point.
test_points = np.array([
    [1.2],
    [1.8],
])
logpdf = normal.logpdf(test_points)
print(f"Log PDF: {logpdf}")

# Differential entropy of the distribution.
entropy = normal.entropy()
print(f"Entropy: {entropy:.3f}")

# Get marginal credible bounds
# Marginal 95% central credible intervals, reported for every dimension rather
# than a hard-coded first two.
means, lower, upper = normal.marginal_credible_bounds()
print("95% credible intervals:")  # plain string: no placeholders to format
for i, (lo, hi) in enumerate(zip(lower, upper)):
    print(f"Dimension {i}: [{lo:.3f}, {hi:.3f}]")

# Diagonal version: keeps the variances, drops the correlations.
diagonal_normal = normal.diagonalise()
diag_samples = diagonal_normal.sample(num=50)

# Create two Normal distributions
# Two 2-dimensional Normals with diagonal covariances.
identity = np.eye(2)
normal1 = stheno.Normal(np.array([[1.0], [0.0]]), identity)
normal2 = stheno.Normal(np.array([[0.0], [1.0]]), 0.5 * identity)

# Sum of two distributions.
sum_normal = normal1 + normal2
sum_mean = sum_normal.mean.flatten()
print(f"Sum mean: {sum_mean}")
print(f"Sum variance diagonal: {sum_normal.var_diag}")

# Scaling: the float is on the left, so this dispatches to ``__rmul__``.
scaled_normal = 2.0 * normal1
scaled_mean = scaled_normal.mean.flatten()
print(f"Scaled mean: {scaled_mean}")
print(f"Scaled variance diagonal: {scaled_normal.var_diag}")

# Shifting by a constant.
shifted_normal = normal1 + 3.0
shifted_mean = shifted_normal.mean.flatten()
print(f"Shifted mean: {shifted_mean}")

# Create transformation matrix
A = np.array([[2.0, 1.0], [0.0, 3.0]])

# Left multiplication: the transformed vector is A @ X.
transformed = normal.lmatmul(A)
print(f"Transformed mean: {transformed.mean.flatten()}")

# Right multiplication via ``rmatmul`` with a (2, 1) matrix. It is named
# ``b_col`` rather than ``B`` so it cannot clobber the ``import lab as B``
# alias used for random-state handling.
b_col = np.array([[1.0, 0.5]]).T
right_transformed = normal.rmatmul(b_col)
print(f"Right transformed shape: {right_transformed.mean.shape}")

# Create data with missing values (NaN)
# An observation for the 2-dimensional ``normal`` whose second entry is
# missing. The vector must have one entry per dimension; the previous
# 3-element vector did not match the distribution. NaN entries mark missing
# values, which logpdf handles automatically (presumably by evaluating the
# density of the observed entries only — confirm).
x_with_missing = np.array([[1.0], [np.nan]])
logpdf_missing = normal.logpdf(x_with_missing)
print(f"Log PDF with missing data: {logpdf_missing}")

# Create two competing distributions
# A standard 2-D Normal and a slightly shifted, slightly wider approximation.
dim = 2
true_dist = stheno.Normal(np.zeros((dim, 1)), np.eye(dim))
approx_dist = stheno.Normal(np.full((dim, 1), 0.1), 1.1 * np.eye(dim))

# KL(true || approx).
kl_div = true_dist.kl(approx_dist)
print(f"KL divergence: {kl_div:.3f}")

# 2-Wasserstein distance between the two.
w2_dist = true_dist.w2(approx_dist)
print(f"2-Wasserstein distance: {w2_dist:.3f}")

# KL is asymmetric, so the reverse direction generally gives a different value.
kl_rev = approx_dist.kl(true_dist)
print(f"Reverse KL divergence: {kl_rev:.3f}")

# Create Normal with lazy evaluation
def mean_func():
    """Return the mean, announcing when it is actually evaluated."""
    print("Computing mean...")
    column = np.array([[1.0], [2.0]])
    return column


def var_func():
    """Return the full covariance, announcing when it is evaluated."""
    print("Computing variance...")
    covariance = np.array([[2.0, 0.3], [0.3, 1.5]])
    return covariance


def var_diag_func():
    """Return only the covariance diagonal, announcing when it is evaluated."""
    print("Computing variance diagonal...")
    diagonal = np.array([2.0, 1.5])
    return diagonal


lazy_normal = stheno.Normal(mean_func, var_func, var_diag=var_diag_func)

# Properties are computed on demand: each access triggers its own function.
print("Accessing mean:")
lazy_mean = lazy_normal.mean  # triggers mean_func
print("Accessing variance diagonal:")
lazy_var_diag = lazy_normal.var_diag  # served by var_diag_func, not var_func

# The Normal class works with different numerical backends
# through the LAB abstraction layer
# Example with numpy arrays (default)
backend_mean = np.array([[1.0], [2.0]])
backend_var = np.array([[1.0, 0.2], [0.2, 1.0]])
numpy_normal = stheno.Normal(backend_mean, backend_var)

# When using backend-specific modules, tensors are handled automatically:
#   import stheno.torch  # would enable PyTorch tensors
#   import stheno.jax    # would enable JAX arrays
#   (and similarly for the AutoGrad and TensorFlow backends)
backend_dtype = numpy_normal.dtype
print(f"Backend dtype: {backend_dtype}")

# Explicit random state control for reproducible sampling
import lab as B

# ``sample(state, ...)`` returns the advanced state alongside the samples, so
# the state is threaded through successive calls.
state = B.create_random_state(B.default_dtype, seed=42)

state, sample1 = normal.sample(state, num=10)
state, sample2 = normal.sample(state, num=10)
print(f"Sample 1 shape: {sample1.shape}")
print(f"Sample 2 shape: {sample2.shape}")

# Successive draws differ, yet rerunning with the same seed reproduces them.
print(f"Samples differ: {not np.allclose(sample1, sample2)}")

# Install with Tessl CLI
npx tessl i tessl/pypi-stheno