Implementation of Gaussian processes in Python with support for AutoGrad, TensorFlow, PyTorch, and JAX
npx @tessl/cli install tessl/pypi-stheno@1.4.0

A comprehensive Python library for Gaussian process modeling that enables probabilistic machine learning with support for multiple backend frameworks including AutoGrad, TensorFlow, PyTorch, and JAX. Stheno provides a flexible and expressive API for constructing sophisticated Gaussian process models, including support for multi-output regression, sparse approximations with inducing points, custom kernels and means, batched computation, and hyperparameter optimization.
pip install stheno

import stheno

For specific backend support:
# Backend-specific imports (choose one)
from stheno.autograd import GP, EQ # AutoGrad backend
from stheno.tensorflow import GP, EQ # TensorFlow backend
from stheno.torch import GP, EQ # PyTorch backend
from stheno.jax import GP, EQ # JAX backend

Each backend provides the same API but uses a different computational framework for automatic differentiation and GPU acceleration.
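For example, a minimal sketch using the PyTorch backend (assuming PyTorch is installed); swapping the import line switches to another backend:

import torch
from stheno.torch import GP, EQ

# Prior GP with an exponentiated quadratic kernel.
f = GP(EQ())

# Stheno works directly with PyTorch tensors.
x = torch.linspace(0, 10, 50)
y = torch.sin(x)

# Condition on noisy observations and query the posterior marginals.
f_post = f | (f(x, 0.1), y)
mean, var = f_post(x).marginals()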
Common objects and functions:
from stheno import GP, Measure, FDD, Normal, PseudoObs, Obs
from stheno import EQ, Matern52, Linear # Kernels from mlkernels

import stheno
import numpy as np
# Create a Gaussian process with an exponential quadratic kernel
gp = stheno.GP(stheno.EQ())
# Generate sample data
x = np.linspace(0, 2, 10)
y = np.sin(x) + 0.1 * np.random.randn(len(x))
# Create finite-dimensional distribution at observation points
fdd = gp(x)
# Condition on observations to get posterior (using | operator)
posterior = gp | (fdd, y)
# Make predictions at new points
x_new = np.linspace(0, 2, 100)
pred_fdd = posterior(x_new)
# Get mean and credible bounds
mean, lower, upper = pred_fdd.marginal_credible_bounds()
# Sample from the posterior
samples = pred_fdd.sample(5)

Stheno's architecture is built around several key concepts, described in the capability sections below.
This design enables flexible GP model construction, efficient inference, and seamless integration with modern ML frameworks while maintaining mathematical rigor and computational efficiency.
Fundamental GP construction, evaluation, conditioning, and posterior inference. Includes creating GPs with custom kernels and means, evaluating at points to create finite-dimensional distributions, and conditioning on observations.
class GP:
def __init__(self, mean=None, kernel=None, *, measure=None, name=None): ...
def __call__(self, x, noise=None) -> FDD: ...
def condition(self, *args): ...
def __or__(self, *args): ... # Shorthand for condition

class FDD(Normal):
def __init__(self, p, x, noise=None): ...
p: GP # Process of FDD
x: Any # Inputs
noise: Optional[Any] # Additive noise
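A minimal sketch of these core operations (variable names are illustrative):

import numpy as np
from stheno import GP, EQ

f = GP(EQ())                        # zero-mean prior with an EQ kernel
x_train = np.linspace(0, 5, 20)
y_train = np.sin(x_train)

fdd = f(x_train, 0.1)               # FDD at the inputs with observation noise 0.1
print(fdd.logpdf(y_train))          # log marginal likelihood of the data

f_post = f.condition(fdd, y_train)  # explicit conditioning
f_post = f | (fdd, y_train)         # equivalent shorthand using |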
Mathematical operations for combining and transforming Gaussian processes, including addition, multiplication, differentiation, input transformations, and dimension selection.

def cross(*ps) -> GP: ... # Cartesian product of processes
class GP:
def __add__(self, other): ... # Addition
def __mul__(self, other): ... # Multiplication
def shift(self, shift): ... # Input shifting
def stretch(self, stretch): ... # Input stretching
def transform(self, f): ... # Input transformation
def select(self, *dims): ... # Dimension selection
def diff(self, dim=0): ... # Differentiation
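A sketch of combining and transforming processes; placing both GPs under a common Measure so they can be combined is an assumption about the intended usage:

from stheno import GP, Measure, EQ, Matern52

prior = Measure()
f = GP(EQ(), measure=prior)
g = GP(Matern52(), measure=prior)

h = f + g                   # sum of two processes
scaled = 2 * f              # scale a process by a constant
shifted = f.shift(1.0)      # shift the inputs
stretched = f.stretch(2.0)  # stretch the inputs
df = f.diff()               # derivative of the process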
GP Arithmetic and Transformations

Advanced model organization using measures to manage collections of GPs, naming, cross-referencing, and maintaining relationships between processes.
class Measure:
def __init__(self): ...
def add_independent_gp(self, p, mean, kernel): ...
def name(self, p, name): ...
def __call__(self, p): ...
def condition(self, obs): ...
def sample(self, *args): ...
def logpdf(self, *args): ...
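A sketch of organizing related processes under a Measure and conditioning the measure as a whole (the | shorthand is assumed to mirror GP conditioning):

import numpy as np
from stheno import GP, Measure, EQ, Matern52

prior = Measure()
f = GP(EQ(), measure=prior)          # smooth component
g = GP(Matern52(), measure=prior)    # rougher component
h = f + g                            # observed process

x = np.linspace(0, 3, 30)
y = np.sin(x) + 0.05 * np.random.randn(len(x))

post = prior | (h(x, 0.05), y)       # condition the whole measure
f_post = post(f)                     # posterior over the smooth component
h_post = post(h)                     # posterior over the sum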
Structured observation handling including standard observations and sparse approximations (VFE, FITC, DTC) for scalable GP inference with inducing points.

class Observations:
def __init__(self, fdd, y): ...
def __init__(self, *pairs): ...
class PseudoObservations:
def __init__(self, u, fdd, y): ...
def elbo(self, measure): ...
method: str # "vfe", "fitc", or "dtc"
class PseudoObservationsFITC(PseudoObservations): ...
class PseudoObservationsDTC(PseudoObservations): ...
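A sketch of sparse inference with pseudo-observations (the inducing point locations x_ind and data are illustrative):

import numpy as np
from stheno import GP, EQ, PseudoObs

f = GP(EQ())
x = np.linspace(0, 10, 200)
y = np.sin(x) + 0.1 * np.random.randn(len(x))
x_ind = np.linspace(0, 10, 20)           # inducing point locations

obs = PseudoObs(f(x_ind), f(x, 0.1), y)  # VFE-style pseudo-observations
print(obs.elbo(f.measure))               # evidence lower bound
f_post = f | obs                         # approximate posterior process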
Support for multi-output GPs using specialized kernel and mean functions that handle vector-valued processes and cross-covariances between outputs.

class MultiOutputKernel:
def __init__(self, measure, *ps): ...
measure: Measure
ps: Tuple[GP, ...]
class MultiOutputMean:
def __init__(self, measure, *ps): ...
def __call__(self, x): ...
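A sketch of building a multi-output kernel from two latent processes; the stacked-output layout of the resulting covariance is an assumption:

import numpy as np
from stheno import GP, Measure, EQ, Matern52, MultiOutputKernel

prior = Measure()
f1 = GP(EQ(), measure=prior)
f2 = GP(Matern52(), measure=prior)

mok = MultiOutputKernel(prior, f1, f2)   # kernel covering both outputs

x = np.linspace(0, 1, 5)
K = mok(x)                               # joint covariance across outputs at x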
Probabilistic computation with Normal distributions, including marginal calculations, sampling, entropy, KL divergence, and Wasserstein distances.

class Normal(RandomVector):
def __init__(self, mean, var): ...
def __init__(self, var): ...
mean: Any # Mean vector
var: Any # Covariance matrix
def marginals(self): ...
def logpdf(self, x): ...
def sample(self, num=1, noise=None): ...
def entropy(self): ...
def kl(self, other): ...
def w2(self, other): ...
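A sketch of working with Normal directly:

import numpy as np
from stheno import Normal

dist = Normal(np.zeros((3, 1)), 0.5 * np.eye(3))   # mean vector and covariance
standard = Normal(np.eye(3))                       # zero mean, given covariance

s = dist.sample()                # a single joint sample
print(dist.logpdf(s))            # log density of the sample
print(dist.entropy())            # differential entropy
print(dist.kl(standard))         # KL divergence to the second Normal
print(dist.w2(standard))         # 2-Wasserstein distance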
Random Variables and Distributions

Efficient computation with lazy evaluation for vectors and matrices that build values on-demand using custom rules and caching.
class LazyVector:
def __init__(self): ...
def add_rule(self, indices, builder): ...
class LazyMatrix:
def __init__(self): ...
def add_rule(self, indices, builder): ...
def add_left_rule(self, i_left, indices, builder): ...
def add_right_rule(self, i_right, indices, builder): ...

# Core types from stheno
class BreakingChangeWarning(UserWarning): ...
# Random object hierarchy
class Random: ...
class RandomProcess(Random): ...
class RandomVector(Random): ...
# Re-exported from dependencies
B = lab # Backend abstraction
matrix = matrix # Structured matrices

Stheno re-exports functionality from several key libraries:
B (the lab backend library) for array operations and linear algebra