# Random Variables and Distributions

Probabilistic computation with Normal distributions and random object hierarchies. This module provides the foundation for uncertainty quantification, sampling, and probabilistic inference in Gaussian process models.

## Capabilities

### Random Object Hierarchy

Base classes for random objects. They provide the common arithmetic operators and serve as the foundation for probabilistic modeling.

```python { .api }
class Random:
    """Base class for random objects."""

    def __radd__(self, other):
        """Right addition (other + self)."""

    def __rmul__(self, other):
        """Right multiplication (other * self)."""

    def __neg__(self):
        """Negation (-self)."""

    def __sub__(self, other):
        """Subtraction (self - other)."""

    def __rsub__(self, other):
        """Right subtraction (other - self)."""

    def __div__(self, other):
        """Division (self / other); Python 2 name, Python 3 uses __truediv__."""

    def __truediv__(self, other):
        """True division (self / other)."""


class RandomProcess(Random):
    """Base class for random processes."""


class RandomVector(Random):
    """Base class for random vectors."""
```
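
Each derived operator listed above can be written in terms of `__add__` and `__mul__`, so a subclass only needs to supply those two. The sketch below shows that pattern with a hypothetical `ToyRandom` class that tracks only a scalar mean and variance. It is an illustration built on that assumption, not stheno's implementation.

```python
class ToyRandom:
    """Hypothetical scalar random variable, tracked by its mean and variance."""

    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    # A subclass of a `Random`-style base only needs these two operations.
    def __add__(self, other):
        if isinstance(other, ToyRandom):
            # Sum of independent variables: means and variances add.
            return ToyRandom(self.mean + other.mean, self.var + other.var)
        return ToyRandom(self.mean + other, self.var)  # Shift by a constant.

    def __mul__(self, other):
        # Scaling by a constant scales the variance by its square.
        return ToyRandom(self.mean * other, self.var * other ** 2)

    # Everything below is derived from `__add__` and `__mul__`.
    def __radd__(self, other):
        return self + other

    def __rmul__(self, other):
        return self * other

    def __neg__(self):
        return -1 * self

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return (-self) + other

    def __truediv__(self, other):
        return self * (1 / other)


x = ToyRandom(1.0, 4.0)
y = (3.0 - x) / 2.0
print(y.mean, y.var)  # 1.0 1.0
```

Subtracting an independent variable still adds variances, because negation leaves the variance unchanged.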

### Normal Distribution

Multivariate normal (Gaussian) distribution with the functionality needed for probabilistic computation, sampling, and inference. The constructor is overloaded through multiple dispatch. A `Normal` can be built from a mean and a covariance, from a covariance alone (zero mean), or lazily from functions that compute these quantities on demand.

```python { .api }
class Normal(RandomVector):
    # Each `__init__` below is a separate overload, selected by multiple dispatch.

    def __init__(self, mean, var):
        """
        Initialize a Normal distribution with a mean and a covariance.

        Parameters:
        - mean: Mean vector (column vector)
        - var: Covariance matrix
        """

    def __init__(self, var):
        """
        Initialize a Normal distribution with zero mean.

        Parameters:
        - var: Covariance matrix
        """

    def __init__(self, mean_func, var_func, *, var_diag=None, mean_var=None, mean_var_diag=None):
        """
        Initialize a Normal distribution lazily from functions.

        Parameters:
        - mean_func: Function that returns the mean when called
        - var_func: Function that returns the covariance when called
        - var_diag: Optional function that returns the diagonal of the covariance
        - mean_var: Optional function that returns a (mean, var) tuple
        - mean_var_diag: Optional function that returns a (mean, var_diag) tuple
        """

    def __init__(self, var_func, **kw_args):
        """Initialize lazily with zero mean from a covariance function."""
```
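
The function-based constructors defer computation until a property is first accessed. The minimal sketch below illustrates that idea with a hypothetical `LazyMoments` class built on `functools.cached_property`; stheno's own caching may differ. It also shows the shortcut of using a dedicated `var_diag` function, when one is given, instead of computing the full covariance.

```python
from functools import cached_property


class LazyMoments:
    """Hypothetical lazy container: each function runs at most once, on first access."""

    def __init__(self, mean_func, var_func, *, var_diag=None):
        self._mean_func = mean_func
        self._var_func = var_func
        self._var_diag_func = var_diag

    @cached_property
    def mean(self):
        return self._mean_func()

    @cached_property
    def var(self):
        return self._var_func()

    @cached_property
    def var_diag(self):
        # Prefer the cheap diagonal function; otherwise fall back to the full matrix.
        if self._var_diag_func is not None:
            return self._var_diag_func()
        return [row[i] for i, row in enumerate(self.var)]


calls = []


def mean_func():
    calls.append("mean")
    return [[1.0], [2.0]]


def var_func():
    calls.append("var")
    return [[2.0, 0.3], [0.3, 1.5]]


m = LazyMoments(mean_func, var_func, var_diag=lambda: [2.0, 1.5])
print(calls)  # [] (nothing computed yet)
print(m.var_diag)  # [2.0, 1.5] (var_func is not called)
print(m.mean, m.mean)  # mean_func runs only once
print(calls)  # ['mean']
```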

### Normal Distribution Properties

Access distributional properties, including moments, dimensionality, and data type.

```python { .api }
class Normal:
    @property
    def mean(self):
        """column vector: Mean of the distribution."""

    @property
    def var(self):
        """matrix: Covariance matrix of the distribution."""

    @property
    def var_diag(self):
        """vector: Diagonal of the covariance matrix."""

    @property
    def mean_var(self):
        """tuple[column vector, matrix]: Mean and covariance tuple."""

    @property
    def dtype(self):
        """dtype: Data type of the distribution."""

    @property
    def dim(self):
        """int: Dimensionality of the distribution."""

    @property
    def m2(self):
        """matrix: Second moment matrix, E[x x^T] = var + mean @ mean.T."""

    @property
    def mean_is_zero(self):
        """bool: Whether the mean is identically zero."""
```
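
The second moment is tied to the other properties by `E[x xᵀ] = var + mean @ mean.T`, and `var_diag` and `dim` follow directly from the covariance. The NumPy sketch below checks these identities using a hypothetical helper `second_moment`. It restates the math and does not call stheno.

```python
import numpy as np


def second_moment(mean, var):
    """Hypothetical helper: E[x x^T] = var + mean mean^T for a column-vector mean."""
    return var + mean @ mean.T


mean = np.array([[1.0], [2.0]])  # Column vector, shape (2, 1)
var = np.array([[1.0, 0.5], [0.5, 2.0]])

m2 = second_moment(mean, var)  # m2 == [[2.0, 2.5], [2.5, 6.0]]
var_diag = np.diag(var)  # Diagonal of the covariance
dim = var.shape[0]  # Dimensionality
mean_is_zero = not np.any(mean)  # True only if every entry of the mean is 0
print(var_diag, dim, mean_is_zero)  # [1. 2.] 2 False

# Monte Carlo check: the empirical second moment approaches m2.
rng = np.random.default_rng(0)
x = rng.multivariate_normal(mean[:, 0], var, size=200_000).T  # Shape (2, N)
emp = x @ x.T / x.shape[1]
print(emp)
```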

### Marginal Operations

Compute marginal statistics and credible intervals for individual components of the multivariate distribution.

```python { .api }
class Normal:
    def marginals(self):
        """
        Get marginal means and variances.

        Returns:
        - tuple: (marginal_means, marginal_variances), each a vector
        """

    def marginal_credible_bounds(self):
        """
        Get marginal 95% central credible interval bounds.

        Returns:
        - tuple: (means, lower_bounds, upper_bounds)
        """

    def diagonalise(self):
        """
        Create a diagonal version by setting all correlations to zero.

        Returns:
        - Normal: Diagonal version of the distribution
        """
```
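
For a Gaussian marginal, the 95% central credible interval is the mean plus or minus about 1.96 standard deviations. Diagonalising keeps only the variances. The NumPy sketch below implements both, assuming the 1.96 multiplier (the standard normal 97.5% quantile). The helper names are hypothetical.

```python
import numpy as np


def credible_bounds(mean, var, z=1.96):
    """Hypothetical helper: marginal means and 95% central interval bounds."""
    marg_mean = mean[:, 0]
    marg_var = np.maximum(np.diag(var), 0.0)  # Guard against tiny negative round-off
    err = z * np.sqrt(marg_var)
    return marg_mean, marg_mean - err, marg_mean + err


def diagonalise(mean, var):
    """Hypothetical helper: drop all correlations, keep the marginal variances."""
    return mean, np.diag(np.diag(var))


mean = np.array([[1.0], [2.0]])
var = np.array([[1.0, 0.5], [0.5, 4.0]])

m, lower, upper = credible_bounds(mean, var)
print(lower, upper)  # lower == [-0.96, -1.92], upper == [2.96, 5.92]

_, var_d = diagonalise(mean, var)
print(var_d)  # Off-diagonal entries are now zero
```

The marginal bounds depend only on the diagonal, so diagonalising does not change them.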

### Probability Computations

Compute probability densities, entropies, and divergences for model evaluation and comparison.

```python { .api }
class Normal:
    def logpdf(self, x):
        """
        Compute the log probability density function.

        Parameters:
        - x: Values at which to evaluate the log-density

        Returns:
        - Log probability density (scalar or array)
        """

    def entropy(self):
        """
        Compute the differential entropy of the distribution.

        Returns:
        - scalar: Entropy value
        """

    def kl(self, other):
        """
        Compute the KL divergence with respect to another Normal distribution.

        Parameters:
        - other: Other Normal distribution

        Returns:
        - scalar: KL divergence D_KL(self || other)
        """

    def w2(self, other):
        """
        Compute the 2-Wasserstein distance to another Normal distribution.

        Parameters:
        - other: Other Normal distribution

        Returns:
        - scalar: 2-Wasserstein distance
        """
```
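
All four quantities have closed forms for Gaussians. The NumPy sketch below writes them out so the definitions are explicit. The helper names are hypothetical, the sketch assumes full-rank covariances, and `w2` returns the distance itself (the square root of the squared distance), matching the docstring above.

```python
import numpy as np


def logpdf(x, mean, var):
    """log N(x; mean, var) for a column vector x."""
    n = mean.shape[0]
    diff = x - mean
    _, logdet = np.linalg.slogdet(var)
    quad = (diff.T @ np.linalg.solve(var, diff)).item()
    return -0.5 * (logdet + n * np.log(2 * np.pi) + quad)


def entropy(var):
    """Differential entropy: 0.5 * (n log(2 pi e) + log det var)."""
    n = var.shape[0]
    return 0.5 * (n * np.log(2 * np.pi * np.e) + np.linalg.slogdet(var)[1])


def kl(mean_p, var_p, mean_q, var_q):
    """D_KL(p || q) between two multivariate normals."""
    n = mean_p.shape[0]
    diff = mean_q - mean_p
    trace_term = np.trace(np.linalg.solve(var_q, var_p))
    quad = (diff.T @ np.linalg.solve(var_q, diff)).item()
    logdet_ratio = np.linalg.slogdet(var_q)[1] - np.linalg.slogdet(var_p)[1]
    return 0.5 * (trace_term + quad - n + logdet_ratio)


def _sqrtm_psd(a):
    """Symmetric PSD matrix square root via an eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    return (vecs * np.sqrt(np.maximum(vals, 0.0))) @ vecs.T


def w2(mean_p, var_p, mean_q, var_q):
    """2-Wasserstein distance between two multivariate normals."""
    root_q = _sqrtm_psd(var_q)
    cross = _sqrtm_psd(root_q @ var_p @ root_q)
    sq = np.sum((mean_p - mean_q) ** 2) + np.trace(var_p + var_q - 2 * cross)
    return np.sqrt(max(sq, 0.0))


mean_p, var_p = np.zeros((2, 1)), np.eye(2)
mean_q, var_q = np.full((2, 1), 0.1), 1.1 * np.eye(2)

print(logpdf(np.zeros((2, 1)), mean_p, var_p))  # -log(2 pi) ≈ -1.8379
print(entropy(var_p))  # log(2 pi e) ≈ 2.8379
print(kl(mean_p, var_p, mean_q, var_q), kl(mean_q, var_q, mean_p, var_p))  # Asymmetric
print(w2(mean_p, var_p, mean_q, var_q))
```

Unlike KL, W2 is symmetric in its two arguments and is a true distance.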

### Sampling Operations

Generate samples from the Normal distribution. Noise can optionally be added, and the random state can optionally be managed explicitly.

```python { .api }
class Normal:
    # `sample` is overloaded: pass a random state to manage it explicitly.

    def sample(self, state, num=1, noise=None):
        """
        Sample from the distribution with an explicit random state.

        Parameters:
        - state: Random state for sampling
        - num: Number of samples to generate
        - noise: Optional variance of noise to add to the samples

        Returns:
        - tuple: (new_state, samples)
        """

    def sample(self, num=1, noise=None):
        """
        Sample from the distribution using the global random state.

        Parameters:
        - num: Number of samples to generate
        - noise: Optional variance of noise to add to the samples

        Returns:
        - tensor: Samples as rank-2 column vectors, one column per sample
        """
```
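
A common way to draw such samples is the reparameterisation `x = mean + L @ z`, where `L` is a Cholesky factor of the covariance and `z` is standard normal. The NumPy sketch below makes three assumptions. NumPy's `Generator` stands in for the random state. `noise` is added to the diagonal of the covariance, meaning independent noise on each entry. The hypothetical `sample_normal` threads the state through explicitly, like the first overload.

```python
import numpy as np


def sample_normal(state, mean, var, num=1, noise=None):
    """Hypothetical sampler returning (state, samples) with samples of shape (dim, num)."""
    if noise is not None:
        # Assumption: `noise` is a variance added independently to every dimension.
        var = var + noise * np.eye(var.shape[0])
    chol = np.linalg.cholesky(var)  # var == chol @ chol.T
    z = state.standard_normal((var.shape[0], num))
    # NumPy's Generator advances in place; returning it mirrors the (state, samples) API.
    return state, mean + chol @ z


mean = np.array([[1.0], [2.0]])
var = np.array([[1.0, 0.5], [0.5, 2.0]])

state = np.random.default_rng(42)
state, s1 = sample_normal(state, mean, var, num=100_000)
state, s2 = sample_normal(state, mean, var, num=3, noise=0.1)
print(s1.shape, s2.shape)  # (2, 100000) (2, 3)
print(np.cov(s1))  # Close to var

# The same seed reproduces the same samples.
_, again = sample_normal(np.random.default_rng(42), mean, var, num=100_000)
print(np.array_equal(s1, again))  # True
```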

### Arithmetic Operations

Combine Normal distributions with scalars, matrices, and other Normal distributions. Each result is again a Normal distribution.

```python { .api }
class Normal:
    def __add__(self, other):
        """
        Add a scalar or another Normal distribution.

        Adding another Normal treats the two distributions as independent.

        Parameters:
        - other: Scalar or Normal distribution

        Returns:
        - Normal: Resulting distribution
        """

    def __mul__(self, other):
        """
        Multiply by a scalar.

        Parameters:
        - other: Scalar multiplier

        Returns:
        - Normal: Scaled distribution
        """

    def lmatmul(self, other):
        """
        Left matrix multiplication (other @ self).

        Parameters:
        - other: Matrix to multiply with

        Returns:
        - Normal: Transformed distribution
        """

    def rmatmul(self, other):
        """
        Right matrix multiplication (self @ other).

        Parameters:
        - other: Matrix to multiply with

        Returns:
        - Normal: Transformed distribution
        """
```
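
Each operation maps the mean and covariance in closed form:

- Adding an independent Normal adds the means and the covariances.
- Adding a constant shifts only the mean.
- Scaling by `c` multiplies the covariance by `c**2`.
- `A @ x` has mean `A @ mean` and covariance `A @ var @ A.T`.

The NumPy sketch below encodes these rules with hypothetical helpers that represent a distribution as a `(mean, var)` pair. The addition rule holds only under the independence assumption.

```python
import numpy as np


def add_independent(p, q):
    """Sum of two independent normals, each given as (mean, var)."""
    return p[0] + q[0], p[1] + q[1]


def add_constant(p, c):
    return p[0] + c, p[1]  # Only the mean moves


def scale(p, c):
    return c * p[0], c**2 * p[1]  # Variance scales quadratically


def lmatmul(a, p):
    """Distribution of a @ x for x ~ N(mean, var)."""
    return a @ p[0], a @ p[1] @ a.T


p1 = (np.array([[1.0], [0.0]]), np.eye(2))
p2 = (np.array([[0.0], [1.0]]), 0.5 * np.eye(2))

s = add_independent(p1, p2)
print(s[0].ravel(), np.diag(s[1]))  # [1. 1.] [1.5 1.5]
print(scale(p1, 2.0)[0].ravel(), np.diag(scale(p1, 2.0)[1]))  # [2. 0.] [4. 4.]
print(add_constant(p1, 3.0)[0].ravel())  # [4. 3.]

a = np.array([[2.0, 1.0], [0.0, 3.0]])
t_mean, t_var = lmatmul(a, p1)
print(t_mean.ravel())  # [2. 0.]
print(t_var)  # a @ a.T == [[5, 3], [3, 9]]
```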

### Backend Operations

Low-level operations for dtype handling and casting across different backends.

```python { .api }
def dtype(dist):
    """
    Get the data type of a Normal distribution.

    Parameters:
    - dist: Normal distribution

    Returns:
    - Data type
    """


def cast(dtype, dist):
    """
    Cast a Normal distribution to the specified data type.

    Parameters:
    - dtype: Target data type
    - dist: Normal distribution to cast

    Returns:
    - Normal: Distribution with the specified dtype
    """
```
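
Casting a distribution means casting its mean and covariance together, so that downstream linear algebra runs in a single precision. The minimal sketch below uses a hypothetical `cast_moments` helper and covers NumPy only. The documented `cast` is meant to work across backends.

```python
import numpy as np


def cast_moments(dtype, mean, var):
    """Hypothetical helper: cast both moments to `dtype`, keeping them consistent."""
    return mean.astype(dtype), var.astype(dtype)


mean = np.array([[1.0], [2.0]])  # float64 by default
var = np.array([[1.0, 0.2], [0.2, 1.0]])

mean32, var32 = cast_moments(np.float32, mean, var)
print(mean.dtype, mean32.dtype, var32.dtype)  # float64 float32 float32
```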

## Usage Examples

### Basic Normal Distribution Operations

```python
import stheno
import numpy as np

# Create a simple Normal distribution
mean = np.array([[1.0], [2.0]])  # Column vector
cov = np.array([[1.0, 0.5], [0.5, 2.0]])
normal = stheno.Normal(mean, cov)

# Access properties
print(f"Mean: {normal.mean.flatten()}")
print(f"Variance diagonal: {normal.var_diag}")
print(f"Dimensionality: {normal.dim}")
print(f"Data type: {normal.dtype}")

# Compute marginals
marginal_means, marginal_vars = normal.marginals()
print(f"Marginal means: {marginal_means}")
print(f"Marginal variances: {marginal_vars}")
```

### Sampling and Probability Computations

```python
# Uses `normal` and `np` from the example above.
# Sample from the distribution
samples = normal.sample(num=100)
print(f"Sample shape: {samples.shape}")  # (2, 100): one column per sample

# Compute the log probability density
test_points = np.array([[1.2], [1.8]])
logpdf = normal.logpdf(test_points)
print(f"Log PDF: {logpdf}")

# Compute the entropy
entropy = normal.entropy()
print(f"Entropy: {entropy:.3f}")
```

### Credible Intervals

```python
# Uses `normal` from the examples above.
# Get marginal credible bounds
means, lower, upper = normal.marginal_credible_bounds()
print("95% credible intervals:")
print(f"Dimension 0: [{lower[0]:.3f}, {upper[0]:.3f}]")
print(f"Dimension 1: [{lower[1]:.3f}, {upper[1]:.3f}]")

# Create a diagonal version (correlations set to zero)
diagonal_normal = normal.diagonalise()
diag_samples = diagonal_normal.sample(num=50)
```

### Distribution Arithmetic

```python
# Uses `stheno` and `np` from the examples above.
# Create two Normal distributions
normal1 = stheno.Normal(np.array([[1.0], [0.0]]), np.eye(2))
normal2 = stheno.Normal(np.array([[0.0], [1.0]]), 0.5 * np.eye(2))

# Addition of independent distributions: means and covariances add
sum_normal = normal1 + normal2
print(f"Sum mean: {sum_normal.mean.flatten()}")
print(f"Sum variance diagonal: {sum_normal.var_diag}")

# Scale the distribution: the variance scales by 2.0 ** 2
scaled_normal = 2.0 * normal1
print(f"Scaled mean: {scaled_normal.mean.flatten()}")
print(f"Scaled variance diagonal: {scaled_normal.var_diag}")

# Add a constant: only the mean shifts
shifted_normal = normal1 + 3.0
print(f"Shifted mean: {shifted_normal.mean.flatten()}")
```

### Linear Transformations

```python
# Uses `normal` and `np` from the examples above.
# Create a transformation matrix
A = np.array([[2.0, 1.0], [0.0, 3.0]])

# Left multiplication: A @ X
transformed = normal.lmatmul(A)
print(f"Transformed mean: {transformed.mean.flatten()}")

# Right multiplication with a column of weights
w = np.array([[1.0, 0.5]]).T  # Shape (2, 1)
right_transformed = normal.rmatmul(w)
print(f"Right transformed shape: {right_transformed.mean.shape}")
```

### Missing Data Handling

```python
# Uses `normal` (2-dimensional) and `np` from the examples above.
# Data with a missing value (NaN); its length must match `normal.dim`
x_with_missing = np.array([[1.0], [np.nan]])

# NaN entries are treated as missing, so only the observed entries contribute
logpdf_missing = normal.logpdf(x_with_missing)
print(f"Log PDF with missing data: {logpdf_missing}")
```
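
Treating NaN entries as missing amounts to evaluating the marginal density of the observed entries. This drops the missing rows of the mean and the matching rows and columns of the covariance. The NumPy sketch below shows that idea for a single column vector, using a hypothetical `logpdf_with_missing` helper.

```python
import numpy as np


def logpdf_with_missing(x, mean, var):
    """Hypothetical helper: log-density of the non-NaN entries of column vector x."""
    observed = ~np.isnan(x[:, 0])
    x_o = x[observed]
    mean_o = mean[observed]
    var_o = var[np.ix_(observed, observed)]  # Marginal covariance of observed entries
    diff = x_o - mean_o
    _, logdet = np.linalg.slogdet(var_o)
    quad = (diff.T @ np.linalg.solve(var_o, diff)).item()
    return -0.5 * (logdet + x_o.shape[0] * np.log(2 * np.pi) + quad)


mean = np.array([[1.0], [2.0]])
var = np.array([[1.0, 0.5], [0.5, 2.0]])
x = np.array([[1.0], [np.nan]])

# Only the first entry is observed, so this is log N(1.0; 1.0, 1.0) = -0.5 log(2 pi).
print(logpdf_with_missing(x, mean, var))  # ≈ -0.9189
```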

### Distribution Comparison

```python
# Uses `stheno` and `np` from the examples above.
# Create two competing distributions
true_dist = stheno.Normal(np.zeros((2, 1)), np.eye(2))
approx_dist = stheno.Normal(np.array([[0.1], [0.1]]), 1.1 * np.eye(2))

# Compute the KL divergence
kl_div = true_dist.kl(approx_dist)
print(f"KL divergence: {kl_div:.3f}")

# Compute the 2-Wasserstein distance
w2_dist = true_dist.w2(approx_dist)
print(f"2-Wasserstein distance: {w2_dist:.3f}")

# KL divergence is not symmetric: compare with the reversed order
kl_rev = approx_dist.kl(true_dist)
print(f"Reverse KL divergence: {kl_rev:.3f}")
```

### Lazy Evaluation with Functions

```python
# Uses `stheno` and `np` from the examples above.
# Create a Normal with lazy evaluation
def mean_func():
    print("Computing mean...")
    return np.array([[1.0], [2.0]])


def var_func():
    print("Computing variance...")
    return np.array([[2.0, 0.3], [0.3, 1.5]])


def var_diag_func():
    print("Computing variance diagonal...")
    return np.array([2.0, 1.5])


lazy_normal = stheno.Normal(
    mean_func,
    var_func,
    var_diag=var_diag_func,
)

# Properties are computed on demand
print("Accessing mean:")
mean = lazy_normal.mean  # Triggers mean computation

print("Accessing variance diagonal:")
var_diag = lazy_normal.var_diag  # Uses var_diag_func, not var_func
```

### Working with Different Backends

```python
# The Normal class works with different numerical backends
# through the LAB abstraction layer.

# Example with NumPy arrays (default)
numpy_normal = stheno.Normal(
    np.array([[1.0], [2.0]]),
    np.array([[1.0, 0.2], [0.2, 1.0]]),
)

# When using backend-specific modules, tensors are handled automatically
# import stheno.torch  # Would enable PyTorch tensors
# import stheno.jax  # Would enable JAX arrays
# etc.

print(f"Backend dtype: {numpy_normal.dtype}")
```

### Random State Management

```python
# Explicit random state control for reproducible sampling
# (uses `normal` and `np` from the examples above)
import lab as B

state = B.create_random_state(B.default_dtype, seed=42)

# Sample with an explicit state; each call returns the updated state
state, sample1 = normal.sample(state, num=10)
state, sample2 = normal.sample(state, num=10)

print(f"Sample 1 shape: {sample1.shape}")
print(f"Sample 2 shape: {sample2.shape}")

# Consecutive samples differ, but rerunning with the same seed reproduces them
print(f"Samples differ: {not np.allclose(sample1, sample2)}")
```