Implementation of Gaussian processes in Python with support for AutoGrad, TensorFlow, PyTorch, and JAX
Advanced model organization using measures to manage collections of Gaussian processes. Measures provide centralized management of GP relationships, naming, conditioning operations, and sampling across multiple processes simultaneously.
Create and manage measures that serve as containers for collections of related Gaussian processes with shared operations and naming.
class Measure:
    """Container for a collection of related Gaussian processes."""

    def __init__(self):
        """Initialize a new measure."""
        # Attributes of a measure; these bare annotations assign no values.
        ps: List[GP]  # List of processes in the measure
        means: LazyVector  # Lazy vector of mean functions
        kernels: LazyMatrix  # Lazy matrix of kernel functions
        # NOTE(review): `ClassVar` marks `default` as shared by all measures,
        # so it presumably belongs at class level — confirm.
        default: ClassVar[Optional[Measure]]  # Global default measure


# Use measures as context managers to temporarily set the default measure for
# GP construction within a scope.
class Measure:
    """Context-manager interface of a measure (API stub)."""

    # The return annotation is quoted: it is evaluated while the class body
    # runs, before the name `Measure` is bound.
    def __enter__(self) -> "Measure":
        """Enter context manager, setting as default measure."""

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, restoring previous default."""


# Manage GP names and retrieve GPs by name or names by GP for better
# organization and debugging.
class Measure:
    """Naming interface of a measure (API stub)."""

    def __getitem__(self, key):
        """
        Get GP by name or get name by GP.
        Parameters:
        - key: GP name (str) or GP object
        Returns:
        - GP or str: GP object if key is name, name if key is GP
        """

    def name(self, p, name):
        """
        Assign name to a GP.
        Parameters:
        - p: Gaussian process to name
        - name: Name to assign
        """


# Add new Gaussian processes to the measure with specified means and kernels,
# or manage existing processes.
class Measure:
    """Interface for adding processes to a measure (API stub)."""

    def add_independent_gp(self, p, mean, kernel):
        """
        Add independent GP to the measure.
        Parameters:
        - p: GP object to add
        - mean: Mean function for the GP
        - kernel: Kernel function for the GP
        """

    def add_gp(self, mean, kernel, left_rule, right_rule=None):
        """
        Add GP with custom kernel rules.
        Parameters:
        - mean: Mean function
        - kernel: Kernel function
        - left_rule: Left kernel rule
        - right_rule: Optional right kernel rule
        Returns:
        - GP: The added process
        """

    def __call__(self, p):
        """
        Apply measure to GP or FDD.
        Parameters:
        - p: GP or FDD to apply measure to
        Returns:
        - Applied measure result
        """


# Perform operations between GPs within the same measure, including
# summation, multiplication, and transformations.
class Measure:
    """Operations between GPs within one measure (API stub).

    Each method takes the resulting GP as its first argument, followed by
    the source GP(s) and any operation parameters.
    """

    def sum(self, p_sum, p1, p2):
        """
        Perform sum operation between two GPs.
        Parameters:
        - p_sum: Resulting sum GP
        - p1: First GP operand
        - p2: Second GP operand
        """

    def mul(self, p_mul, p1, p2):
        """
        Perform multiplication between two GPs.
        Parameters:
        - p_mul: Resulting product GP
        - p1: First GP operand
        - p2: Second GP operand
        """

    def shift(self, p_shifted, p, shift):
        """
        Shift GP inputs.
        Parameters:
        - p_shifted: Resulting shifted GP
        - p: Source GP
        - shift: Shift amount
        """

    def stretch(self, p_stretched, p, stretch):
        """
        Stretch GP inputs.
        Parameters:
        - p_stretched: Resulting stretched GP
        - p: Source GP
        - stretch: Stretch factor
        """

    def select(self, p_selected, p, *dims):
        """
        Select dimensions from GP.
        Parameters:
        - p_selected: Resulting GP
        - p: Source GP
        - *dims: Dimensions to select
        """

    def transform(self, p_transformed, p, f):
        """
        Transform GP inputs.
        Parameters:
        - p_transformed: Resulting transformed GP
        - p: Source GP
        - f: Transformation function
        """

    def diff(self, p_diff, p, dim=0):
        """
        Differentiate GP.
        Parameters:
        - p_diff: Resulting derivative GP
        - p: Source GP
        - dim: Dimension to differentiate
        """

    def cross(self, p_cross, *ps):
        """
        Create cross product of GPs.
        Parameters:
        - p_cross: Resulting cross product GP
        - *ps: GPs to combine
        """


# Condition the entire measure on observations, creating posterior measures
# with updated beliefs across all processes.
class Measure:
    """Conditioning interface of a measure (API stub)."""

    def condition(self, obs):
        """
        Condition measure on observations.
        Parameters:
        - obs: Observations object
        Returns:
        - Measure: Posterior measure
        """

    def __or__(self, *args):
        """Shorthand for condition() using | operator."""


# Sample from multiple processes simultaneously, maintaining correlations and
# relationships between processes within the measure.
class Measure:
    """Joint sampling interface of a measure (API stub)."""

    # NOTE(review): the three `sample` definitions below document three call
    # forms of one method. In plain Python each later `def` rebinds `sample`,
    # so only the last would remain; presumably the library dispatches on the
    # argument types — confirm.
    def sample(self, state, n, *fdds):
        """
        Sample from multiple processes with random state.
        Parameters:
        - state: Random state for sampling
        - n: Number of samples
        - *fdds: FDDs to sample from
        Returns:
        - Tuple: (new_state, samples)
        """

    def sample(self, n, *fdds):
        """
        Sample from multiple processes without explicit state.
        Parameters:
        - n: Number of samples
        - *fdds: FDDs to sample from
        Returns:
        - Samples from the processes
        """

    def sample(self, *fdds):
        """
        Sample single realization from processes.
        Parameters:
        - *fdds: FDDs to sample from
        Returns:
        - Single sample from the processes
        """


# Compute log probability densities for observations under the measure,
# useful for model comparison and hyperparameter optimization.
class Measure:
    """Log-density interface of a measure (API stub)."""

    # NOTE(review): the two `logpdf` definitions document two call forms of
    # one method. In plain Python the second `def` rebinds `logpdf`;
    # presumably the library dispatches on the argument types — confirm.
    def logpdf(self, *pairs):
        """
        Compute log probability density for observation pairs.
        Parameters:
        - *pairs: (FDD, values) pairs
        Returns:
        - Log probability density
        """

    def logpdf(self, obs):
        """
        Compute log probability density for observations.
        Parameters:
        - obs: Observations object
        Returns:
        - Log probability density
        """


# Usage examples.
import stheno
import numpy as np
# Create a measure and use it as a context manager: GPs constructed inside the
# `with` block are created while `measure` is the default measure.
measure = stheno.Measure()
with measure:
    gp1 = stheno.GP(kernel=stheno.EQ(), name="signal")
    gp2 = stheno.GP(kernel=stheno.Matern52(), name="noise")

# Indexing with a name returns the GP registered under that name ...
signal_gp, noise_gp = measure["signal"], measure["noise"]

# ... and indexing with a GP returns its name.
signal_name = measure[signal_gp]
print(f"Signal GP: {signal_name}")

# Create measure for climate model
climate = stheno.Measure()
with climate:
    # Temperature process.
    temp = stheno.GP(
        kernel=stheno.EQ().stretch(2.0) * stheno.Matern52().stretch(0.5),
        name="temperature",
    )
    # Humidity is correlated with temperature through the shared `temp` term.
    humidity = 0.8 * temp + stheno.GP(kernel=stheno.EQ(), name="humidity_residual")
    climate.name(humidity, "humidity")
    # Pressure: a single stretched EQ process (long length scale), independent
    # of temperature and humidity. No periodic/seasonal kernel is used here.
    pressure = stheno.GP(kernel=stheno.EQ().stretch(10.0), name="pressure")

# Work with the model.
x = np.linspace(0, 365, 100)  # Days in year
temp_fdd = temp(x)
humidity_fdd = humidity(x)
pressure_fdd = pressure(x)

# Generate observations. Sample temperature and humidity *jointly* from the
# measure so the synthetic data respects their correlation; two separate
# `fdd.sample()` calls would draw them independently.
temp_obs, humidity_obs = climate.sample(temp_fdd, humidity_fdd)

# Condition the entire measure on all observations. `condition` takes a single
# observations object, so both (FDD, values) pairs go into one `Observations`.
posterior_climate = climate.condition(
    stheno.Observations((temp_fdd, temp_obs), (humidity_fdd, humidity_obs))
)

# All processes are now conditioned; look them up by name.
posterior_temp = posterior_climate["temperature"]
posterior_humidity = posterior_climate["humidity"]
posterior_pressure = posterior_climate["pressure"]

# Sample from multiple processes simultaneously.
x_pred = np.linspace(0, 365, 50)

# Individual FDDs.
temp_pred = posterior_temp(x_pred)
humidity_pred = posterior_humidity(x_pred)
pressure_pred = posterior_pressure(x_pred)

# Joint sampling maintains correlations. Sample from the *posterior* measure:
# sampling from the prior `climate` would ignore the observations.
samples = posterior_climate.sample(5, temp_pred, humidity_pred, pressure_pred)

# Access individual process samples (one entry per FDD, in argument order).
temp_samples = samples[0]  # First process samples
humidity_samples = samples[1]  # Second process samples
pressure_samples = samples[2]  # Third process samples

# Create competing models
model1 = stheno.Measure()
model2 = stheno.Measure()
with model1:
    gp1 = stheno.GP(kernel=stheno.EQ())
with model2:
    gp2 = stheno.GP(kernel=stheno.Matern52())

# Test data. A seeded generator makes the model comparison reproducible.
rng = np.random.default_rng(0)
x_test = np.linspace(0, 1, 20)
y_test = np.sin(x_test) + 0.1 * rng.standard_normal(len(x_test))

# Compute log marginal likelihoods. `Measure.logpdf` takes (FDD, values)
# pairs (or an observations object), so pass each FDD with its data as a tuple.
fdd1 = gp1(x_test, noise=0.1)
fdd2 = gp2(x_test, noise=0.1)
logpdf1 = model1.logpdf((fdd1, y_test))
logpdf2 = model2.logpdf((fdd2, y_test))

print(f"Model 1 log likelihood: {logpdf1}")
print(f"Model 2 log likelihood: {logpdf2}")
print(f"Model {'1' if logpdf1 > logpdf2 else '2'} is preferred")

# Create measure with custom operations
advanced_measure = stheno.Measure()

# Create bare GP objects; their means and kernels are supplied through the
# measure calls below.
gp_base = stheno.GP()
gp_derived = stheno.GP()

# Register gp_base in the measure with a zero mean and an EQ kernel.
advanced_measure.add_independent_gp(gp_base, stheno.ZeroMean(), stheno.EQ())

# Create derived process through measure operations.
advanced_measure.sum(gp_derived, gp_base, gp_base)  # gp_derived = 2 * gp_base

# Use derived relationships.
x = np.linspace(0, 1, 50)
# Sample both processes *jointly* from the measure. Two separate
# `fdd.sample()` calls would draw independent realizations, and the ratio
# check below would then essentially never hold.
base_samples, derived_samples = advanced_measure.sample(gp_base(x), gp_derived(x))

# NOTE(review): the joint covariance is rank-deficient (derived is an exact
# multiple of base), so numerical jitter in the sampler presumably makes the
# match only approximate; the tolerance is a guess — tune if needed.
is_double = np.allclose(derived_samples, 2 * base_samples, atol=1e-3)
print(f"Derived should be ~2x base: {is_double}")

# Install with Tessl CLI:
npx tessl i tessl/pypi-stheno